Parakeet/parakeet/utils/mp_tools.py

22 lines
360 B
Python
Raw Normal View History

2020-12-01 18:13:30 +08:00
import paddle
from paddle import distributed as dist
from functools import wraps
__all__ = ["rank_zero_only"]
2020-12-01 18:13:30 +08:00
def rank_zero_only(func):
    """Decorate ``func`` so that it only runs on the process whose rank is 0.

    On any other rank the wrapped call is skipped and returns ``None``.

    The rank is queried with ``dist.get_rank()`` each time the wrapped
    function is *called*, not when it is decorated. Querying it at
    decoration time (i.e. at import time for module-level functions)
    freezes whatever rank was visible when the module was imported. That
    may be wrong if the distributed environment is set up after import,
    e.g. in worker processes started by a spawn-style launcher.

    Args:
        func: the callable to guard.

    Returns:
        A wrapper carrying ``func``'s metadata (via ``functools.wraps``).
        It returns ``func``'s result on rank 0 and ``None`` on other ranks.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        if dist.get_rank() != 0:
            # Non-zero ranks skip the call entirely.
            return None
        return func(*args, **kwargs)

    return wrapper