add layer_tools

This commit is contained in:
chenfeiyu 2020-02-13 02:27:04 +00:00 committed by liuyibing01
parent 155dfe633d
commit 6cbcadebda
1 changed files with 35 additions and 0 deletions

View File

@ -0,0 +1,35 @@
import numpy as np
from torch import nn
import paddle.fluid.dygraph as dg
def summary(layer):
num_params = num_elements = 0
print("layer summary:")
for name, param in layer.state_dict().items():
print("{}|{}|{}".format(name, param.shape, np.prod(param.shape)))
num_elements += np.prod(param.shape)
num_params += 1
print("layer has {} parameters, {} elements.".format(
num_params, num_elements))
def freeze(layer):
for param in layer.parameters():
param.trainable = False
def unfreeze(layer):
for param in layer.parameters():
param.trainable = True
def torch_summary(layer):
num_params = num_elements = 0
print("layer summary:")
for name, param in layer.named_parameters():
print("{}|{}|{}".format(name, param.shape, np.prod(param.shape)))
num_elements += np.prod(param.shape)
num_params += 1
print("layer has {} parameters, {} elements.".format(
num_params, num_elements))