Are BatchNorm's beta and gamma parameters not trainable when building a network in Paddle?
I built the same network in both Keras and Paddle; both report Total params: 556,383.
However, Keras reports Non-trainable params: 1,792,
while the Paddle version reports Non-trainable params: 3,584.
Inspecting the network structure, I found that the normalization layers' parameters are being counted as non-trainable. How can I fix this?
I searched the Paddle documentation but found no discussion of beta and gamma.
Paddle model summary:
---------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===========================================================================
Conv2D-64 [[1, 31, 32, 32]] [1, 64, 32, 32] 17,920
ReLU-60 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-50 [[1, 64, 32, 32]] [1, 64, 32, 32] 36,928
BatchNorm-43 [[1, 64, 32, 32]] [1, 64, 32, 32] 256
ReLU-46 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
DHSIS_block-44 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-51 [[1, 64, 32, 32]] [1, 64, 32, 32] 36,928
BatchNorm-44 [[1, 64, 32, 32]] [1, 64, 32, 32] 256
ReLU-47 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
DHSIS_block-45 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-52 [[1, 64, 32, 32]] [1, 64, 32, 32] 36,928
BatchNorm-45 [[1, 64, 32, 32]] [1, 64, 32, 32] 256
ReLU-48 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
DHSIS_block-46 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-53 [[1, 64, 32, 32]] [1, 64, 32, 32] 36,928
BatchNorm-46 [[1, 64, 32, 32]] [1, 64, 32, 32] 256
ReLU-49 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
DHSIS_block-47 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-54 [[1, 64, 32, 32]] [1, 64, 32, 32] 36,928
BatchNorm-47 [[1, 64, 32, 32]] [1, 64, 32, 32] 256
ReLU-50 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
DHSIS_block-48 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-55 [[1, 64, 32, 32]] [1, 64, 32, 32] 36,928
BatchNorm-48 [[1, 64, 32, 32]] [1, 64, 32, 32] 256
ReLU-51 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
DHSIS_block-49 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-56 [[1, 64, 32, 32]] [1, 64, 32, 32] 36,928
BatchNorm-49 [[1, 64, 32, 32]] [1, 64, 32, 32] 256
ReLU-52 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
DHSIS_block-50 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-57 [[1, 64, 32, 32]] [1, 64, 32, 32] 36,928
BatchNorm-50 [[1, 64, 32, 32]] [1, 64, 32, 32] 256
ReLU-53 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
DHSIS_block-51 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-58 [[1, 64, 32, 32]] [1, 64, 32, 32] 36,928
BatchNorm-51 [[1, 64, 32, 32]] [1, 64, 32, 32] 256
ReLU-54 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
DHSIS_block-52 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-59 [[1, 64, 32, 32]] [1, 64, 32, 32] 36,928
BatchNorm-52 [[1, 64, 32, 32]] [1, 64, 32, 32] 256
ReLU-55 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
DHSIS_block-53 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-60 [[1, 64, 32, 32]] [1, 64, 32, 32] 36,928
BatchNorm-53 [[1, 64, 32, 32]] [1, 64, 32, 32] 256
ReLU-56 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
DHSIS_block-54 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-61 [[1, 64, 32, 32]] [1, 64, 32, 32] 36,928
BatchNorm-54 [[1, 64, 32, 32]] [1, 64, 32, 32] 256
ReLU-57 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
DHSIS_block-55 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-62 [[1, 64, 32, 32]] [1, 64, 32, 32] 36,928
BatchNorm-55 [[1, 64, 32, 32]] [1, 64, 32, 32] 256
ReLU-58 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
DHSIS_block-56 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-63 [[1, 64, 32, 32]] [1, 64, 32, 32] 36,928
BatchNorm-56 [[1, 64, 32, 32]] [1, 64, 32, 32] 256
ReLU-59 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
DHSIS_block-57 [[1, 64, 32, 32]] [1, 64, 32, 32] 0
Conv2D-65 [[1, 64, 32, 32]] [1, 31, 32, 32] 17,887
===========================================================================
Total params: 556,383
Trainable params: 552,799
Non-trainable params: 3,584
---------------------------------------------------------------------------
Input size (MB): 0.12
Forward/backward pass size (MB): 29.24
Params size (MB): 2.12
Estimated Total Size (MB): 31.49
---------------------------------------------------------------------------
Keras (TF 1.5.0) model summary:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input (InputLayer) (None, 32, 32, 31) 0
_________________________________________________________________
conv11 (Conv2D) (None, 32, 32, 64) 17920
_________________________________________________________________
conv11_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv12 (Conv2D) (None, 32, 32, 64) 36928
_________________________________________________________________
conv12_bn (BatchNormalizatio (None, 32, 32, 64) 256
_________________________________________________________________
conv12_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv13 (Conv2D) (None, 32, 32, 64) 36928
_________________________________________________________________
conv13_bn (BatchNormalizatio (None, 32, 32, 64) 256
_________________________________________________________________
conv13_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv14 (Conv2D) (None, 32, 32, 64) 36928
_________________________________________________________________
conv14_bn (BatchNormalizatio (None, 32, 32, 64) 256
_________________________________________________________________
conv14relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv15 (Conv2D) (None, 32, 32, 64) 36928
_________________________________________________________________
conv15_bn (BatchNormalizatio (None, 32, 32, 64) 256
_________________________________________________________________
conv15_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv16 (Conv2D) (None, 32, 32, 64) 36928
_________________________________________________________________
conv16_bn (BatchNormalizatio (None, 32, 32, 64) 256
_________________________________________________________________
conv16_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv17 (Conv2D) (None, 32, 32, 64) 36928
_________________________________________________________________
conv17_bn (BatchNormalizatio (None, 32, 32, 64) 256
_________________________________________________________________
conv17_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv18 (Conv2D) (None, 32, 32, 64) 36928
_________________________________________________________________
conv18_bn (BatchNormalizatio (None, 32, 32, 64) 256
_________________________________________________________________
conv18_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv19 (Conv2D) (None, 32, 32, 64) 36928
_________________________________________________________________
conv19_bn (BatchNormalizatio (None, 32, 32, 64) 256
_________________________________________________________________
conv19_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv20 (Conv2D) (None, 32, 32, 64) 36928
_________________________________________________________________
conv20_bn (BatchNormalizatio (None, 32, 32, 64) 256
_________________________________________________________________
conv20_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv21 (Conv2D) (None, 32, 32, 64) 36928
_________________________________________________________________
conv21_bn (BatchNormalizatio (None, 32, 32, 64) 256
_________________________________________________________________
conv21_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv22 (Conv2D) (None, 32, 32, 64) 36928
_________________________________________________________________
conv22_bn (BatchNormalizatio (None, 32, 32, 64) 256
_________________________________________________________________
conv22_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv23 (Conv2D) (None, 32, 32, 64) 36928
_________________________________________________________________
conv23_bn (BatchNormalizatio (None, 32, 32, 64) 256
_________________________________________________________________
conv23_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv24 (Conv2D) (None, 32, 32, 64) 36928
_________________________________________________________________
conv24_bn (BatchNormalizatio (None, 32, 32, 64) 256
_________________________________________________________________
conv24_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv25 (Conv2D) (None, 32, 32, 64) 36928
_________________________________________________________________
conv25_bn (BatchNormalizatio (None, 32, 32, 64) 256
_________________________________________________________________
conv25_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv30 (Conv2D) (None, 32, 32, 31) 17887
=================================================================
Total params: 556,383
Trainable params: 554,591
Non-trainable params: 1,792
_________________________________________________________________
XiangLiu (accepted answer, replied 2022-06):
With a standard nn.BatchNorm2D(num_channels), beta and gamma are trained in net.train() mode, and the mean and variance are updated as well. In net.eval() mode, beta and gamma are no longer updated, and the mean and variance used are the final running statistics accumulated during training.
Did this resolve your question? I ran into the same issue.
Not resolved.
After testing, only the displayed parameter counts differ; the actual computation results are identical.
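Consistent with the two tables above, the discrepancy is purely in how the summaries classify BatchNorm's four per-channel tensors. A sketch of the arithmetic, with the layer count read off the summaries (14 BatchNorm layers, 64 channels each):

```python
# Bookkeeping behind the two summaries (counts taken from the tables above).
num_bn_layers = 14   # BatchNorm layers in the network
channels = 64        # channels normalized by each BatchNorm layer

# Each BatchNorm holds 4 per-channel tensors: gamma, beta, moving mean, moving variance.
params_per_bn = 4 * channels                           # 256, matching both summaries

# Keras marks only the moving statistics (mean, variance) as non-trainable.
keras_non_trainable = num_bn_layers * 2 * channels     # 1,792

# Paddle's summary reports all four BatchNorm tensors as non-trainable,
# even though gamma and beta do receive gradients in train() mode.
paddle_non_trainable = num_bn_layers * 4 * channels    # 3,584

# The gap is exactly the gamma/beta parameters: 14 * 2 * 64 = 1,792.
print(paddle_non_trainable - keras_non_trainable)  # -> 1792
```

So the trained weights are identical; only the summary's trainable/non-trainable split differs.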
Learned something new.