add layer_tools
This commit is contained in:
parent
155dfe633d
commit
6cbcadebda
|
@ -0,0 +1,35 @@
|
||||||
|
import numpy as np
|
||||||
|
from torch import nn
|
||||||
|
import paddle.fluid.dygraph as dg
|
||||||
|
|
||||||
|
|
||||||
|
def summary(layer):
|
||||||
|
num_params = num_elements = 0
|
||||||
|
print("layer summary:")
|
||||||
|
for name, param in layer.state_dict().items():
|
||||||
|
print("{}|{}|{}".format(name, param.shape, np.prod(param.shape)))
|
||||||
|
num_elements += np.prod(param.shape)
|
||||||
|
num_params += 1
|
||||||
|
print("layer has {} parameters, {} elements.".format(
|
||||||
|
num_params, num_elements))
|
||||||
|
|
||||||
|
|
||||||
|
def freeze(layer):
|
||||||
|
for param in layer.parameters():
|
||||||
|
param.trainable = False
|
||||||
|
|
||||||
|
|
||||||
|
def unfreeze(layer):
|
||||||
|
for param in layer.parameters():
|
||||||
|
param.trainable = True
|
||||||
|
|
||||||
|
|
||||||
|
def torch_summary(layer):
|
||||||
|
num_params = num_elements = 0
|
||||||
|
print("layer summary:")
|
||||||
|
for name, param in layer.named_parameters():
|
||||||
|
print("{}|{}|{}".format(name, param.shape, np.prod(param.shape)))
|
||||||
|
num_elements += np.prod(param.shape)
|
||||||
|
num_params += 1
|
||||||
|
print("layer has {} parameters, {} elements.".format(
|
||||||
|
num_params, num_elements))
|
Loading…
Reference in New Issue