From 6cbcadebdae01b2c451f466b5d45d0658b859068 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 13 Feb 2020 02:27:04 +0000 Subject: [PATCH] add layer_tools --- parakeet/utils/layer_tools.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 parakeet/utils/layer_tools.py diff --git a/parakeet/utils/layer_tools.py b/parakeet/utils/layer_tools.py new file mode 100644 index 0000000..eaa9c9e --- /dev/null +++ b/parakeet/utils/layer_tools.py @@ -0,0 +1,35 @@ +import numpy as np +from torch import nn +import paddle.fluid.dygraph as dg + + +def summary(layer): + num_params = num_elements = 0 + print("layer summary:") + for name, param in layer.state_dict().items(): + print("{}|{}|{}".format(name, param.shape, np.prod(param.shape))) + num_elements += np.prod(param.shape) + num_params += 1 + print("layer has {} parameters, {} elements.".format( + num_params, num_elements)) + + +def freeze(layer): + for param in layer.parameters(): + param.trainable = False + + +def unfreeze(layer): + for param in layer.parameters(): + param.trainable = True + + +def torch_summary(layer): + num_params = num_elements = 0 + print("layer summary:") + for name, param in layer.named_parameters(): + print("{}|{}|{}".format(name, param.shape, np.prod(param.shape))) + num_elements += np.prod(param.shape) + num_params += 1 + print("layer has {} parameters, {} elements.".format( + num_params, num_elements))