ParakeetRebeccaRosario/parakeet/utils/mp_tools.py

19 lines
329 B
Python
Raw Normal View History

2020-12-01 18:13:30 +08:00
import paddle
from paddle import distributed as dist
from functools import wraps
def rank_zero_only(func):
    """Decorate ``func`` so that it only executes on the rank-0 process.

    The rank is queried with ``paddle.distributed.get_rank()`` every time the
    wrapped function is called, not once when the decorator is applied.
    Decorators normally run at import time, and that may happen before the
    distributed environment of the current process has been configured.
    Caching the rank at that moment could leave the guard using a stale
    value for the rest of the run.

    Args:
        func: the callable to guard.

    Returns:
        A wrapper carrying ``func``'s metadata (via ``functools.wraps``).
        On rank 0 it calls ``func(*args, **kwargs)`` and returns its result;
        on any other rank it skips the call and returns ``None``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Resolve the rank lazily, at call time (see docstring for why).
        if dist.get_rank() != 0:
            return None
        return func(*args, **kwargs)

    return wrapper