对齐静态图时使用random数据
This commit is contained in:
parent
8626855142
commit
8f1b9c8b0d
|
@ -94,13 +94,11 @@ def check_static():
|
||||||
from ppocr.utils.logging import get_logger
|
from ppocr.utils.logging import get_logger
|
||||||
from tools import program
|
from tools import program
|
||||||
|
|
||||||
config = program.load_config('configs/det/det_r50_vd_db.yml')
|
config = program.load_config('configs/rec/rec_r34_vd_none_bilstm_ctc.yml')
|
||||||
|
|
||||||
# import cv2
|
|
||||||
# data = cv2.imread('doc/imgs/1.jpg')
|
|
||||||
# data = normalize(data)
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
data = np.zeros((1, 3, 640, 640), dtype=np.float32)
|
np.random.seed(0)
|
||||||
|
data = np.random.rand(1, 3, 32, 320).astype(np.float32)
|
||||||
paddle.disable_static()
|
paddle.disable_static()
|
||||||
|
|
||||||
config['Architecture']['in_channels'] = 3
|
config['Architecture']['in_channels'] = 3
|
||||||
|
@ -110,17 +108,15 @@ def check_static():
|
||||||
load_dygraph_pretrain(
|
load_dygraph_pretrain(
|
||||||
model,
|
model,
|
||||||
logger,
|
logger,
|
||||||
'/Users/zhoujun20/Desktop/code/PaddleOCR/db/db',
|
'/Users/zhoujun20/Desktop/code/PaddleOCR/cnn_ctc/cnn_ctc',
|
||||||
load_static_weights=True)
|
load_static_weights=True)
|
||||||
x = paddle.to_variable(data)
|
x = paddle.to_tensor(data)
|
||||||
y = model(x)
|
y = model(x)
|
||||||
for y1 in y:
|
for y1 in y:
|
||||||
print(y1.shape)
|
print(y1.shape)
|
||||||
#
|
|
||||||
# # from matplotlib import pyplot as plt
|
static_out = np.load(
|
||||||
# # plt.imshow(y.numpy())
|
'/Users/zhoujun20/Desktop/code/PaddleOCR/output/conv.npy')
|
||||||
# # plt.show()
|
|
||||||
static_out = np.load('/Users/zhoujun20/Desktop/code/PaddleOCR/db/db.npy')
|
|
||||||
diff = y.numpy() - static_out
|
diff = y.numpy() - static_out
|
||||||
print(y.shape, static_out.shape, diff.mean())
|
print(y.shape, static_out.shape, diff.mean())
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue