diff --git a/parakeet/models/waveflow/configs/waveflow_ljspeech.yaml b/parakeet/models/waveflow/configs/waveflow_ljspeech.yaml index f9bbc83..d3548c4 100644 --- a/parakeet/models/waveflow/configs/waveflow_ljspeech.yaml +++ b/parakeet/models/waveflow/configs/waveflow_ljspeech.yaml @@ -12,8 +12,8 @@ seed: 1234 learning_rate: 0.0002 batch_size: 8 test_every: 2000 -save_every: 5000 -max_iterations: 2000000 +save_every: 10000 +max_iterations: 3000000 sigma: 1.0 n_flows: 8 diff --git a/parakeet/models/waveflow/slurm.py b/parakeet/models/waveflow/slurm.py deleted file mode 100644 index de1818c..0000000 --- a/parakeet/models/waveflow/slurm.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -Utility module for restarting training when using SLURM. -""" -import subprocess -import os -import sys -import shlex -import re -import time - - -def job_info(): - """Get information about the current job using `scontrol show job`. - Returns a dict mapping parameter names (e.g. "UserId", "RunTime", etc) to - their values, both as strings. - """ - job_id = int(os.environ["SLURM_JOB_ID"]) - - command = ["scontrol", "show", "job", str(job_id)] - output = subprocess.check_output(command).decode("utf-8") - - # Use a regex to extract the parameter names and values - pattern = "([A-Za-z/]*)=([^ \t\n]*)" - return dict(re.findall(pattern, output)) - - -def parse_hours(text): - """Parse a time format HH or DD-HH into a number of hours.""" - hour_chunks = text.split("-") - if len(hour_chunks) == 1: - return int(hour_chunks[0]) - elif len(hour_chunks) == 2: - return 24 * int(hour_chunks[0]) + int(hour_chunks[1]) - else: - raise ValueError("Unexpected hour format (expected HH or " - "DD-HH, but got {}).".format(text)) - - -def parse_time(text): - """Convert slurm time to an integer. - Expects time to be of the form: - "hours:minutes:seconds" or "day-hours:minutes:seconds". - """ - hours, minutes, seconds = text.split(":") - try: - return parse_hours(hours) * 3600 + int(minutes) * 60 + int(seconds) - except ValueError as e: - raise ValueError("Error parsing time {}. Got error {}.".format( - text, str(e))) - - -def restart_command(): - """Using the environment and SLURM command, create a command that, when, - run, will enqueue a repeat of the current job using `sbatch`. - Return the command as a list of strings, suitable for passing to - `subprocess.check_call` or similar functions. - Returns: - resume_command: list, command to run to restart job. - end_time: int or None; the time the job will end or None - if the job has unlimited runtime. - """ - # Make sure `RunTime` could be parsed correctly. - while job_info()["RunTime"] == "INVALID": - time.sleep(1) - - # Get all the necessary information by querying SLURM with this job id - info = job_info() - - try: - num_cpus = int(info["CPUs/Task"]) - except KeyError: - num_cpus = int(os.environ["SLURM_CPUS_PER_TASK"]) - - num_tasks = int(os.environ["SLURM_NTASKS"]) - nodes = info["NumNodes"] - gres, partition = info.get("Gres"), info.get("Partition") - stderr, stdout = info.get("StdErr"), info.get("StdOut") - job_name = info.get("JobName") - command = ["sbatch", "--job-name={}".format(job_name), - "--ntasks={}".format(num_tasks), - "--exclude=asimov-186"] - - if partition: - command.extend(["--partition", partition]) - - if gres and gres != "(null)": - command.extend(["--gres", gres]) - num_gpu = int(gres.split(':')[-1]) - print("number of gpu assigned by slurm is {}".format(num_gpu)) - - if stderr: - command.extend(["--error", stderr]) - - if stdout: - command.extend(["--output", stdout]) - - python = subprocess.check_output( - ["/usr/bin/which", "python3"]).decode("utf-8").strip() - dist_setting = ['-m', 'paddle.distributed.launch'] - wrap_cmd = ["srun", python, '-u'] + dist_setting + sys.argv - - command.append( - "--wrap={}".format(" ".join(shlex.quote(arg) for arg in wrap_cmd))) - time_limit_string = info["TimeLimit"] - if time_limit_string.lower() == "unlimited": - print("UNLIMITED detected: restart OFF, infinite learning ON.", - flush=True) - return command, None - time_limit = parse_time(time_limit_string) - runtime = parse_time(info["RunTime"]) - end_time = time.time() + time_limit - runtime - - return command, end_time