Parakeet/parakeet/utils/mp_tools.py

19 lines
329 B
Python

import paddle
from paddle import distributed as dist
from functools import wraps
def rank_zero_only(func):
    """Decorate ``func`` so that it only executes on the process of rank 0.

    On any process where ``dist.get_rank()`` is not 0, calling the decorated
    function does nothing and returns ``None``. On rank 0 it calls ``func``
    with the given arguments and returns its result unchanged.

    The rank is queried on every call rather than once at decoration time.
    Decoration usually happens at module import, which may precede
    initialization of the distributed environment. Caching the rank at that
    point could freeze the wrong value for the decorated function's lifetime.

    Args:
        func: The callable to restrict to rank 0.

    Returns:
        A wrapper with ``func``'s metadata (via ``functools.wraps``).
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Non-zero ranks skip the call entirely and implicitly return None.
        if dist.get_rank() != 0:
            return None
        return func(*args, **kwargs)

    return wrapper