# 2020-12-01 18:13:30 +08:00
|
|
|
import paddle
|
|
|
|
from paddle import distributed as dist
|
|
|
|
from functools import wraps
|
|
|
|
|
# 2020-12-09 15:58:39 +08:00
|
|
|
# Public API of this module: only the decorator is exported by
# ``from <module> import *``.
__all__ = [
    "rank_zero_only",
]
|
|
|
|
|
|
|
|
|
# 2020-12-01 18:13:30 +08:00
|
|
|
def rank_zero_only(func):
    """Decorate ``func`` so that it only executes on the rank-0 process.

    The current process's rank is obtained from
    ``paddle.distributed.get_rank()``. When it is 0, the wrapped function is
    called normally and its return value is passed through. On any other
    rank the call is skipped and ``None`` is returned.

    Args:
        func: The callable to guard.

    Returns:
        A wrapper carrying the same name, docstring and other metadata as
        ``func`` (copied via ``functools.wraps``).
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Look the rank up on every call instead of once when the decorator
        # is applied. Decoration usually happens at import time, which may
        # precede setup of the parallel environment; a rank captured then
        # could be stale for the rest of the run.
        if dist.get_rank() != 0:
            return None
        return func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
|
|
|
|
|
|
|