diff --git a/.gitignore b/.gitignore
index 7906666..af9563d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -142,3 +142,5 @@ dmypy.json
 *.swp
 runs
 syn_audios
+exp/
+dump/
diff --git a/examples/parallelwave_gan/baker/batch_fn.py b/examples/parallelwave_gan/baker/batch_fn.py
new file mode 100644
index 0000000..71761e9
--- /dev/null
+++ b/examples/parallelwave_gan/baker/batch_fn.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+
+
+class Clip(object):
+    """Collate functor for training vocoders.
+    """
+
+    def __init__(
+            self,
+            batch_max_steps=20480,
+            hop_size=256,
+            aux_context_window=0, ):
+        """Initialize customized collater for Paddle DataLoader.
+
+        Args:
+            batch_max_steps (int): The maximum length of input signal in batch.
+            hop_size (int): Hop size of auxiliary features.
+            aux_context_window (int): Context window size for auxiliary feature conv.
+
+        """
+        if batch_max_steps % hop_size != 0:
+            batch_max_steps -= batch_max_steps % hop_size
+        assert batch_max_steps % hop_size == 0
+        self.batch_max_steps = batch_max_steps
+        self.batch_max_frames = batch_max_steps // hop_size
+        self.hop_size = hop_size
+        self.aux_context_window = aux_context_window
+
+        # set useful values for random cutting
+        self.start_offset = aux_context_window
+        self.end_offset = -(self.batch_max_frames + aux_context_window)
+        self.mel_threshold = self.batch_max_frames + 2 * aux_context_window
+
+    def __call__(self, examples):
+        """Convert into batch tensors.
+
+        Args:
+            examples (list): List of tuples of the pair of audio and features.
+
+        Returns:
+            Tensor: Auxiliary feature batch (B, C, T'), where
+                T = (T' - 2 * aux_context_window) * hop_size.
+            Tensor: Target signal batch (B, 1, T).
+
+        """
+        # check length
+        examples = [
+            self._adjust_length(*b) for b in examples
+            if len(b[1]) > self.mel_threshold
+        ]
+        xs, cs = [b[0] for b in examples], [b[1] for b in examples]
+
+        # make batch with random cut
+        c_lengths = [len(c) for c in cs]
+        start_frames = np.array([
+            np.random.randint(self.start_offset, cl + self.end_offset)
+            for cl in c_lengths
+        ])
+        x_starts = start_frames * self.hop_size
+        x_ends = x_starts + self.batch_max_steps
+
+        c_starts = start_frames - self.aux_context_window
+        c_ends = start_frames + self.batch_max_frames + self.aux_context_window
+        y_batch = [x[start:end] for x, start, end in zip(xs, x_starts, x_ends)]
+        c_batch = [c[start:end] for c, start, end in zip(cs, c_starts, c_ends)]
+
+        # convert each batch to a tensor, assuming each item in the batch has the same length
+        y_batch = paddle.to_tensor(
+            y_batch, dtype=paddle.float32).unsqueeze(1)  # (B, 1, T)
+        c_batch = paddle.to_tensor(
+            c_batch, dtype=paddle.float32).transpose([0, 2, 1])  # (B, C, T')
+
+        return (c_batch, ), y_batch
+
+    def _adjust_length(self, x, c):
+        """Adjust the audio and feature lengths.
+
+        Note:
+            Basically we assume that the lengths of x and c are already
+            adjusted in the preprocessing stage, but this step is needed
+            when features processed by another library are used.
+
+        """
+        if len(x) < len(c) * self.hop_size:
+            x = np.pad(x, (0, len(c) * self.hop_size - len(x)), mode="edge")
+
+        # check that the length is valid
+        assert len(x) == len(c) * self.hop_size
+
+        return x, c
diff --git a/examples/parallelwave_gan/baker/conf/default.yaml b/examples/parallelwave_gan/baker/conf/default.yaml
new file mode 100644
index 0000000..ee888a2
--- /dev/null
+++ b/examples/parallelwave_gan/baker/conf/default.yaml
@@ -0,0 +1,127 @@
+# This is the hyperparameter configuration file for Parallel WaveGAN.
+# Please make sure this is adjusted for the CSMSC dataset. If you want to
+# apply it to another dataset, you might need to carefully change some parameters.
+# This configuration requires 12 GB GPU memory and takes ~3 days on an RTX TITAN.
+
+###########################################################
+#                FEATURE EXTRACTION SETTING               #
+###########################################################
+sr: 24000                # Sampling rate.
+n_fft: 2048              # FFT size.
+hop_length: 300          # Hop size.
+win_length: 1200         # Window length.
+                         # If set to null, it will be the same as n_fft.
+window: "hann"           # Window function.
+n_mels: 80               # Number of mel basis.
+fmin: 80                 # Minimum frequency in mel basis calculation.
+fmax: 7600               # Maximum frequency in mel basis calculation.
+# global_gain_scale: 1.0 # Will be multiplied to all of waveform.
+trim_silence: false      # Whether to trim the start and end of silence.
+top_db: 60               # Need to tune carefully if the recording is not good.
+trim_frame_length: 2048  # Frame size in trimming.
+trim_hop_length: 512     # Hop size in trimming.
+# format: "npy"          # Feature file format. "npy" or "hdf5" is supported.
+
+###########################################################
+#         GENERATOR NETWORK ARCHITECTURE SETTING          #
+###########################################################
+generator_params:
+    in_channels: 1          # Number of input channels.
+    out_channels: 1         # Number of output channels.
+    kernel_size: 3          # Kernel size of dilated convolution.
+    layers: 30              # Number of residual block layers.
+    stacks: 3               # Number of stacks, i.e., dilation cycles.
+    residual_channels: 64   # Number of channels in residual conv.
+    gate_channels: 128      # Number of channels in gated conv.
+    skip_channels: 64       # Number of channels in skip conv.
+    aux_channels: 80        # Number of channels for auxiliary feature conv.
+                            # Must be the same as n_mels.
+    aux_context_window: 2   # Context window size for auxiliary feature.
+                            # If set to 2, previous 2 and future 2 frames will be considered.
+    dropout: 0.0            # Dropout rate. 0.0 means no dropout applied.
+    bias: true              # Whether to use bias in residual blocks.
+    use_weight_norm: true   # Whether to use weight norm.
+                            # If set to true, it will be applied to all of the conv layers.
+    use_causal_conv: false  # Whether to use causal conv in residual blocks and upsample layers.
+    # upsample_net: "ConvInUpsampleNetwork" # Upsampling network architecture.
+    upsample_scales: [4, 5, 3, 5]  # Upsampling scales. Product of these must equal hop_length.
+    interpolate_mode: "nearest"  # Interpolation mode of the upsampling network.
+    freq_axis_kernel_size: 1     # Convolution kernel size along the frequency axis in the upsampling network.
+    nonlinear_activation: null
+    nonlinear_activation_params: {}
+
+###########################################################
+#       DISCRIMINATOR NETWORK ARCHITECTURE SETTING        #
+###########################################################
+discriminator_params:
+    in_channels: 1         # Number of input channels.
+    out_channels: 1        # Number of output channels.
+    kernel_size: 3         # Kernel size of conv layers.
+    layers: 10             # Number of conv layers.
+    conv_channels: 64      # Number of conv channels.
+    bias: true             # Whether to use bias parameter in conv.
+    use_weight_norm: true  # Whether to use weight norm.
+                           # If set to true, it will be applied to all of the conv layers.
+    nonlinear_activation: "LeakyReLU"  # Nonlinear function after each conv.
+    nonlinear_activation_params:       # Nonlinear function parameters.
+        negative_slope: 0.2            # Alpha in LeakyReLU.
+
+###########################################################
+#                   STFT LOSS SETTING                     #
+###########################################################
+stft_loss_params:
+    fft_sizes: [1024, 2048, 512]   # List of FFT sizes for STFT-based loss.
+    hop_sizes: [120, 240, 50]      # List of hop sizes for STFT-based loss.
+    win_lengths: [600, 1200, 240]  # List of window lengths for STFT-based loss.
+    window: "hann"                 # Window function for STFT-based loss.
+
+###########################################################
+#               ADVERSARIAL LOSS SETTING                  #
+###########################################################
+lambda_adv: 4.0  # Loss balancing coefficient.
+
+###########################################################
+#                  DATA LOADER SETTING                    #
+###########################################################
+batch_size: 6               # Batch size.
+batch_max_steps: 25500      # Length of each audio in batch. Make sure it is divisible by hop_length.
+pin_memory: true            # Whether to pin memory in Paddle DataLoader.
+num_workers: 2              # Number of workers in Paddle DataLoader.
+remove_short_samples: true  # Whether to remove samples shorter than batch_max_steps.
+allow_cache: true           # Whether to allow caching in the dataset. If true, it requires CPU memory.
+
+###########################################################
+#             OPTIMIZER & SCHEDULER SETTING               #
+###########################################################
+generator_optimizer_params:
+    epsilon: 1.0e-6        # Generator's epsilon.
+    weight_decay: 0.0      # Generator's weight decay coefficient.
+generator_scheduler_params:
+    learning_rate: 0.0001  # Generator's learning rate.
+    step_size: 200000      # Generator's scheduler step size.
+    gamma: 0.5             # Generator's scheduler gamma.
+                           # At each step size, lr will be multiplied by this parameter.
+generator_grad_norm: 10    # Generator's gradient clipping norm.
+discriminator_optimizer_params:
+    epsilon: 1.0e-6        # Discriminator's epsilon.
+    weight_decay: 0.0      # Discriminator's weight decay coefficient.
+discriminator_scheduler_params:
+    learning_rate: 0.00005 # Discriminator's learning rate.
+    step_size: 200000      # Discriminator's scheduler step size.
+    gamma: 0.5             # Discriminator's scheduler gamma.
+                           # At each step size, lr will be multiplied by this parameter.
+discriminator_grad_norm: 1 # Discriminator's gradient clipping norm.
+
+###########################################################
+#                    INTERVAL SETTING                     #
+###########################################################
+discriminator_train_start_steps: 100000  # Number of steps before starting to train the discriminator.
+train_max_steps: 400000                  # Number of training steps.
+save_interval_steps: 5000                # Interval steps to save checkpoint.
+eval_interval_steps: 1000                # Interval steps to evaluate the network.
+log_interval_steps: 100                  # Interval steps to record the training log.
+
+###########################################################
+#                     OTHER SETTING                       #
+###########################################################
+num_save_intermediate_results: 4  # Number of results to be saved as intermediate results.
diff --git a/examples/parallelwave_gan/baker/config.py b/examples/parallelwave_gan/baker/config.py
new file mode 100644
index 0000000..f555791
--- /dev/null
+++ b/examples/parallelwave_gan/baker/config.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import yaml
+from yacs.config import CfgNode as Configuration
+
+with open("conf/default.yaml", 'rt') as f:
+    _C = yaml.safe_load(f)
+    _C = Configuration(_C)
+
+
+def get_cfg_default():
+    config = _C.clone()
+    return config
diff --git a/examples/parallelwave_gan/baker/preprocess.py b/examples/parallelwave_gan/baker/preprocess.py
new file mode 100644
index 0000000..9bf9623
--- /dev/null
+++ b/examples/parallelwave_gan/baker/preprocess.py
@@ -0,0 +1,280 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Dict, Any
+import soundfile as sf
+import librosa
+import numpy as np
+from config import get_cfg_default
+import argparse
+import yaml
+import json
+import dacite
+import dataclasses
+import concurrent.futures
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+from pathlib import Path
+import tqdm
+from operator import itemgetter
+from praatio import tgio
+import logging
+
+
+def logmelfilterbank(audio,
+                     sr,
+                     n_fft=1024,
+                     hop_length=256,
+                     win_length=None,
+                     window="hann",
+                     n_mels=80,
+                     fmin=None,
+                     fmax=None,
+                     eps=1e-10):
+    """Compute log-Mel filterbank features.
+
+    Parameters
+    ----------
+    audio : ndarray
+        Audio signal (T,).
+    sr : int
+        Sampling rate.
+    n_fft : int
+        FFT size. (Default value = 1024)
+    hop_length : int
+        Hop size. (Default value = 256)
+    win_length : int
+        Window length. If set to None, it will be the same as n_fft. (Default value = None)
+    window : str
+        Window function type. (Default value = "hann")
+    n_mels : int
+        Number of mel basis. (Default value = 80)
+    fmin : int
+        Minimum frequency in mel basis calculation. (Default value = None)
+    fmax : int
+        Maximum frequency in mel basis calculation. (Default value = None)
+    eps : float
+        Epsilon value to avoid inf in log calculation. (Default value = 1e-10)
+
+    Returns
+    -------
+    np.ndarray
+        Log Mel filterbank feature (num_mels, #frames).
+
+    """
+    # get amplitude spectrogram
+    x_stft = librosa.stft(
+        audio,
+        n_fft=n_fft,
+        hop_length=hop_length,
+        win_length=win_length,
+        window=window,
+        pad_mode="reflect")
+    spc = np.abs(x_stft)  # (#bins, #frames,)
+
+    # get mel basis
+    fmin = 0 if fmin is None else fmin
+    fmax = sr / 2 if fmax is None else fmax
+    mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax)
+
+    return np.log10(np.maximum(eps, np.dot(mel_basis, spc)))
+
+
+def process_sentence(config: Dict[str, Any],
+                     fp: Path,
+                     alignment_fp: Path,
+                     output_dir: Path):
+    utt_id = fp.stem
+
+    # reading
+    y, sr = librosa.load(fp, sr=config.sr)  # resampling may occur
+    assert len(y.shape) == 1, f"{utt_id} is not mono-channel audio."
+    assert np.abs(y).max(
+    ) <= 1.0, f"{utt_id} does not seem to be normalized 16-bit PCM."
+    duration = librosa.get_duration(y, sr=sr)
+
+    # trim according to the alignment file
+    alignment = tgio.openTextgrid(alignment_fp)
+    intervals = alignment.tierDict[alignment.tierNameList[0]].entryList
+    first, last = intervals[0], intervals[-1]
+    start = 0
+    end = last.end
+    if first.label == "sil" and first.end < duration:
+        start = first.end
+    else:
+        logging.warning(
+            f"There is something wrong with the first interval {first} in utterance {utt_id}"
+        )
+    if last.label == "sil" and last.start < duration:
+        end = last.start
+    else:
+        end = duration
+        logging.warning(
+            f"There is something wrong with the last interval {last} in utterance {utt_id}"
+        )
+    # silence trimmed
+    start, end = librosa.time_to_samples([start, end], sr=sr)
+    y = y[start:end]
+
+    # energy-based silence trimming
+    if config.trim_silence:
+        y, _ = librosa.effects.trim(
+            y,
+            top_db=config.top_db,
+            frame_length=config.trim_frame_length,
+            hop_length=config.trim_hop_length)
+
+    logmel = logmelfilterbank(
+        y,
+        sr=sr,
+        n_fft=config.n_fft,
+        window=config.window,
+        win_length=config.win_length,
+        hop_length=config.hop_length,
+        n_mels=config.n_mels,
+        fmin=config.fmin,
+        fmax=config.fmax)
+
+    # adjust time to make num_samples == num_frames * hop_length
+    num_frames = logmel.shape[1]
+    y = np.pad(y, (0, config.n_fft), mode="reflect")
+    y = y[:num_frames * config.hop_length]
+    num_sample = y.shape[0]
+
+    mel_path = output_dir / (utt_id + "_feats.npy")
+    wav_path = output_dir / (utt_id + "_wave.npy")
+    np.save(wav_path, y)
+    np.save(mel_path, logmel)
+    record = {
+        "utt_id": utt_id,
+        "num_samples": num_sample,
+        "num_frames": num_frames,
+        "feats_path": str(mel_path.resolve()),
+        "wave_path": str(wav_path.resolve()),
+    }
+    return record
+
+
+def process_sentences(config,
+                      fps: List[Path],
+                      alignment_fps: List[Path],
+                      output_dir: Path,
+                      nprocs: int=1):
+    if nprocs == 1:
+        results = []
+        for fp, alignment_fp in tqdm.tqdm(zip(fps, alignment_fps)):
+            results.append(
+                process_sentence(config, fp, alignment_fp, output_dir))
+    else:
+        with ProcessPoolExecutor(nprocs) as pool:
+            futures = []
+            with tqdm.tqdm(total=len(fps)) as progress:
+                for fp, alignment_fp in zip(fps, alignment_fps):
+                    future = pool.submit(process_sentence, config, fp,
+                                         alignment_fp, output_dir)
+                    future.add_done_callback(lambda p: progress.update())
+                    futures.append(future)
+
+                results = []
+                for ft in futures:
+                    results.append(ft.result())
+
+    results.sort(key=itemgetter("utt_id"))
+    with open(output_dir / "metadata.json", 'wt') as f:
+        json.dump(results, f)
+    print("Done")
+
+
+def main():
+    # parse config and args
+    parser = argparse.ArgumentParser(
+        description="Preprocess audio and then extract features (See detail in parallel_wavegan/bin/preprocess.py)."
+    )
+    parser.add_argument(
+        "--rootdir",
+        default=None,
+        type=str,
+        help="directory including wav files and alignment files.")
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        required=True,
+        help="directory to dump feature files.")
+    parser.add_argument(
+        "--config", type=str, help="yaml format configuration file.")
+    parser.add_argument(
+        "--verbose",
+        type=int,
+        default=1,
+        help="logging level. higher is more logging. (default=1)")
+    parser.add_argument(
+        "--num_cpu", type=int, default=1, help="number of processes.")
+    args = parser.parse_args()
+
+    C = get_cfg_default()
+    if args.config:
+        C.merge_from_file(args.config)
+    C.freeze()
+
+    if args.verbose > 1:
+        print(vars(args))
+        print(C)
+
+    root_dir = Path(args.rootdir)
+    dumpdir = Path(args.dumpdir)
+    dumpdir.mkdir(parents=True, exist_ok=True)
+
+    wav_files = sorted(list((root_dir / "Wave").rglob("*.wav")))
+    alignment_files = sorted(
+        list((root_dir / "PhoneLabeling").rglob("*.interval")))
+
+    # split data into 3 sections
+    train_wav_files = wav_files[:9800]
+    dev_wav_files = wav_files[9800:9900]
+    test_wav_files = wav_files[9900:]
+
+    train_alignment_files = alignment_files[:9800]
+    dev_alignment_files = alignment_files[9800:9900]
+    test_alignment_files = alignment_files[9900:]
+
+    train_dump_dir = dumpdir / "train"
+    train_dump_dir.mkdir(parents=True, exist_ok=True)
+    dev_dump_dir = dumpdir / "dev"
+    dev_dump_dir.mkdir(parents=True, exist_ok=True)
+    test_dump_dir = dumpdir / "test"
+    test_dump_dir.mkdir(parents=True, exist_ok=True)
+
+    # process for the 3 sections
+    process_sentences(
+        C,
+        train_wav_files,
+        train_alignment_files,
+        train_dump_dir,
+        nprocs=args.num_cpu)
+    process_sentences(
+        C,
+        dev_wav_files,
+        dev_alignment_files,
+        dev_dump_dir,
+        nprocs=args.num_cpu)
+    process_sentences(
+        C,
+        test_wav_files,
+        test_alignment_files,
+        test_dump_dir,
+        nprocs=args.num_cpu)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/parallelwave_gan/baker/train.py b/examples/parallelwave_gan/baker/train.py
new file mode 100644
index 0000000..bed3e21
--- /dev/null
+++ b/examples/parallelwave_gan/baker/train.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
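+"""Training script for Parallel WaveGAN on the Baker (CSMSC) dataset."""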
+
+import os
+import sys
+import logging
+import argparse
+import dataclasses
+from pathlib import Path
+
+import yaml
+import json
+import paddle
+import numpy as np
+from paddle import nn
+from paddle.nn import functional as F
+from paddle import distributed as dist
+from paddle.io import DataLoader, DistributedBatchSampler
+from paddle.optimizer import Adam  # No RAdam
+from paddle.optimizer.lr import StepDecay
+from paddle import DataParallel
+from visualdl import LogWriter
+
+from parakeet.datasets.data_table import DataTable
+from parakeet.training.updater import UpdaterBase
+from parakeet.training.trainer import Trainer
+from parakeet.training.reporter import report
+from parakeet.training.checkpoint import KBest, KLatest
+from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
+from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
+
+from config import get_cfg_default
+
+
+def train_sp(args, config):
+    # decide device type and whether to run in parallel,
+    # then set up the running environment accordingly
+    if not paddle.is_compiled_with_cuda():
+        paddle.set_device("cpu")
+    else:
+        paddle.set_device("gpu")
+    world_size = paddle.distributed.get_world_size()
+    if world_size > 1:
+        paddle.distributed.init_parallel_env()
+
+    print(
+        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
+    )
+
+    # construct dataset for training and validation
+    with open(args.train_metadata) as f:
+        train_metadata = json.load(f)
+    train_dataset = DataTable(
+        data=train_metadata,
+        fields=["wave_path", "feats_path"],
+        converters={
+            "wave_path": np.load,
+            "feats_path": np.load,
+        }, )
+    with open(args.dev_metadata) as f:
+        dev_metadata = json.load(f)
+    dev_dataset = DataTable(
+        data=dev_metadata,
+        fields=["wave_path", "feats_path"],
+        converters={
+            "wave_path": np.load,
+            "feats_path": np.load,
+        }, )
+
+    # collate function and dataloader
+    train_sampler = DistributedBatchSampler(
+        train_dataset,
+        batch_size=config.batch_size,
+        shuffle=True,
+        drop_last=True)
+    dev_sampler = DistributedBatchSampler(
+        dev_dataset,
+        batch_size=config.batch_size,
+        shuffle=False,
+        drop_last=False)
+
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_sampler=train_sampler,
+        collate_fn=None,  # TODO: define a collate fn (see batch_fn.Clip)
+        num_workers=4)
+    dev_dataloader = DataLoader(
+        dev_dataset,
+        batch_sampler=dev_sampler,
+        collate_fn=None,  # TODO: define a collate fn (see batch_fn.Clip)
+        num_workers=4)
+
+    generator = PWGGenerator(**config["generator_params"])
+    discriminator = PWGDiscriminator(**config["discriminator_params"])
+    if world_size > 1:
+        generator = DataParallel(generator)
+        discriminator = DataParallel(discriminator)
+    criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"])
+    criterion_mse = nn.MSELoss()
+    lr_schedule_g = StepDecay(**config["generator_scheduler_params"])
+    optimizer_g = Adam(
+        lr_schedule_g,
+        parameters=generator.parameters(),
+        **config["generator_optimizer_params"])
+    lr_schedule_d = StepDecay(**config["discriminator_scheduler_params"])
+    optimizer_d = Adam(
+        lr_schedule_d,
+        parameters=discriminator.parameters(),
+        **config["discriminator_optimizer_params"])
+
+    output_dir = Path(args.output_dir)
+    log_writer = None
+    if dist.get_rank() == 0:
+        output_dir.mkdir(parents=True, exist_ok=True)
+        log_writer = LogWriter(str(output_dir))
+
+    # TODO: training loop
+
+
+def main():
+    # parse args and config and redirect to train_sp
+    parser = argparse.ArgumentParser(description="Train a ParallelWaveGAN "
+                                     "model with Baker Mandarin TTS dataset.")
+    parser.add_argument(
+        "--config", type=str, help="config file to overwrite default config")
+    parser.add_argument("--train-metadata", type=str, help="training data")
+    parser.add_argument("--dev-metadata", type=str, help="dev data")
+    parser.add_argument("--output-dir", type=str, help="output dir")
+    parser.add_argument(
+        "--nprocs", type=int, default=1, help="number of processes")
+    parser.add_argument("--verbose", type=int, default=1, help="verbose")
+
+    args = parser.parse_args()
+    config = get_cfg_default()
+    if args.config:
+        config.merge_from_file(args.config)
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+    print(
+        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
+    )
+
+    # dispatch
+    if args.nprocs > 1:
+        dist.spawn(train_sp, (args, config), nprocs=args.nprocs)
+    else:
+        train_sp(args, config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/parakeet/models/parallel_wavegan.py b/parakeet/models/parallel_wavegan.py
index 00a48a4..e7d202d 100644
--- a/parakeet/models/parallel_wavegan.py
+++ b/parakeet/models/parallel_wavegan.py
@@ -109,10 +109,11 @@ class UpsampleNet(nn.Layer):
             padding = (freq_axis_padding, scale)
             conv = nn.Conv2D(
                 1, 1, kernel_size, padding=padding, bias_attr=False)
+            self.up_layers.extend([stretch, conv])
             if nonlinear_activation is not None:
                 nonlinear = getattr(
                     nn, nonlinear_activation)(**nonlinear_activation_params)
-                self.up_layers.extend([stretch, conv, nonlinear])
+                self.up_layers.append(nonlinear)
 
     def forward(self, c: Tensor) -> Tensor:
         """
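A possible follow-up for the `collate_fn=None` TODO in train.py: a minimal sketch (not part of this diff) of wiring the new Clip collate functor into the training DataLoader. It assumes each dataset example reaches the collator as an (audio, mel) pair with the mel shaped (T', n_mels); note that preprocess.py saves features as (n_mels, T'), so a transpose is needed on one side. The names config, train_dataset, and train_sampler are reused from train_sp above.

    from paddle.io import DataLoader

    from batch_fn import Clip

    # Build the collator from the config keys defined in conf/default.yaml.
    train_batch_fn = Clip(
        batch_max_steps=config.batch_max_steps,
        hop_size=config.hop_length,
        aux_context_window=config["generator_params"]["aux_context_window"])

    # Each batch then yields ((mel (B, C, T'),), wave (B, 1, T)),
    # as produced by Clip.__call__.
    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=train_batch_fn,
        num_workers=config.num_workers)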