small change
This commit is contained in:
parent
4af577ad72
commit
91ab2b34c4
|
@ -12,8 +12,8 @@ seed: 1234
|
|||
learning_rate: 0.0002
|
||||
batch_size: 8
|
||||
test_every: 2000
|
||||
save_every: 5000
|
||||
max_iterations: 2000000
|
||||
save_every: 10000
|
||||
max_iterations: 3000000
|
||||
|
||||
sigma: 1.0
|
||||
n_flows: 8
|
||||
|
|
|
@ -1,113 +0,0 @@
|
|||
"""
|
||||
Utility module for restarting training when using SLURM.
|
||||
"""
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
import shlex
|
||||
import re
|
||||
import time
|
||||
|
||||
|
||||
def job_info():
    """Query SLURM for the parameters of the current job.

    Runs `scontrol show job <id>` for the job named by the `SLURM_JOB_ID`
    environment variable and collects every `Name=value` pair found in the
    command's output.

    Returns:
        dict<str, str>: parameter names (e.g. "UserId", "RunTime") mapped to
            their values; both keys and values are strings.
    """
    slurm_job_id = int(os.environ["SLURM_JOB_ID"])

    raw = subprocess.check_output(
        ["scontrol", "show", "job", str(slurm_job_id)])
    report = raw.decode("utf-8")

    # A name is a run of letters and slashes; its value runs up to the next
    # space, tab or newline.
    key_value = re.compile("([A-Za-z/]*)=([^ \t\n]*)")
    return {name: value for name, value in key_value.findall(report)}
|
||||
|
||||
|
||||
def parse_hours(text):
    """Turn an hour field written as HH or DD-HH into a number of hours.

    Args:
        text: str, either "HH" or "DD-HH".

    Returns:
        int: the hours, with each day counted as 24 hours.

    Raises:
        ValueError: if `text` contains more than one "-" or a piece is not
            an integer.
    """
    pieces = text.split("-")
    if len(pieces) > 2:
        raise ValueError("Unexpected hour format (expected HH or "
                         "DD-HH, but got {}).".format(text))

    # `day_part` is empty for the plain HH form.
    *day_part, hour_part = pieces
    days = int(day_part[0]) if day_part else 0
    return 24 * days + int(hour_part)
|
||||
|
||||
|
||||
def parse_time(text):
    """Convert a SLURM time string into a number of seconds.

    Expects `text` to be of the form "hours:minutes:seconds" or
    "days-hours:minutes:seconds".

    Args:
        text: str, the time string to convert.

    Returns:
        int: the total number of seconds.

    Raises:
        ValueError: if `text` does not consist of exactly three
            ":"-separated fields, or a field is not an integer. The message
            names the offending text and the underlying error.
    """
    try:
        # Unpack inside the `try` so that a wrong number of fields is
        # reported with the same descriptive message as a non-numeric field.
        hours, minutes, seconds = text.split(":")
        return parse_hours(hours) * 3600 + int(minutes) * 60 + int(seconds)
    except ValueError as e:
        # Chain the original exception so its traceback is preserved.
        raise ValueError("Error parsing time {}. Got error {}.".format(
            text, str(e))) from e
|
||||
|
||||
|
||||
def restart_command():
    """Build an `sbatch` command that re-enqueues the current SLURM job.

    The command is assembled from the `scontrol show job` parameters of the
    running job (via `job_info`), the `SLURM_NTASKS` environment variable
    and `sys.argv`. The wrapped command re-runs the current script through
    `srun <python3> -u -m paddle.distributed.launch`.

    The command is returned as a list of strings, suitable for passing to
    `subprocess.check_call` or similar functions.

    Returns:
        resume_command: list<str>, command to run to restart job.
        end_time: float or None; the `time.time()` timestamp at which the
            job will hit its time limit, or None if the job has unlimited
            runtime.

    Raises:
        KeyError: if `SLURM_NTASKS` is not set, or SLURM does not report
            `RunTime` / `TimeLimit` for the job.
        ValueError: if the reported time limit or runtime cannot be parsed.
    """
    # Poll until SLURM reports a usable `RunTime`, and keep the snapshot that
    # passed the check so every value used below comes from validated output.
    info = job_info()
    while info["RunTime"] == "INVALID":
        time.sleep(1)
        info = job_info()

    # NOTE(review): CPUs-per-task and the node count are not forwarded to
    # `sbatch` here - confirm whether the resubmitted job should request them.
    num_tasks = int(os.environ["SLURM_NTASKS"])
    gres, partition = info.get("Gres"), info.get("Partition")
    stderr, stdout = info.get("StdErr"), info.get("StdOut")
    job_name = info.get("JobName")

    # NOTE(review): the excluded node name is hard-coded - confirm it is
    # still the intended exclusion.
    command = [
        "sbatch",
        "--job-name={}".format(job_name),
        "--ntasks={}".format(num_tasks),
        "--exclude=asimov-186",
    ]

    if partition:
        command.extend(["--partition", partition])

    if gres and gres != "(null)":
        command.extend(["--gres", gres])
        # The GPU count is only logged, so a GRES spec without a trailing
        # integer must not abort building the restart command.
        try:
            num_gpu = int(gres.split(":")[-1])
        except ValueError:
            print("could not read a gpu count from gres {}".format(gres))
        else:
            print("number of gpu assigned by slurm is {}".format(num_gpu))

    if stderr:
        command.extend(["--error", stderr])

    if stdout:
        command.extend(["--output", stdout])

    python = subprocess.check_output(
        ["/usr/bin/which", "python3"]).decode("utf-8").strip()
    dist_setting = ["-m", "paddle.distributed.launch"]
    wrap_cmd = ["srun", python, "-u"] + dist_setting + sys.argv

    # `--wrap` takes a single shell string, so every argument is quoted.
    command.append(
        "--wrap={}".format(" ".join(shlex.quote(arg) for arg in wrap_cmd)))

    time_limit_string = info["TimeLimit"]
    if time_limit_string.lower() == "unlimited":
        print("UNLIMITED detected: restart OFF, infinite learning ON.",
              flush=True)
        return command, None

    # End time = now + (time limit - time already consumed by the job).
    time_limit = parse_time(time_limit_string)
    runtime = parse_time(info["RunTime"])
    end_time = time.time() + time_limit - runtime

    return command, end_time
|
Loading…
Reference in New Issue