Are BatchNorm's beta and gamma parameters not trainable when building a network in Paddle?
I built the same network in Keras and in Paddle; both report Total params: 556,383.
However, Keras reports Non-trainable params: 1,792,
while the Paddle network reports Non-trainable params: 3,584.
Inspecting the network structure, I found that the normalization layers' parameters are being counted as non-trainable. How do I fix this?
I searched the Paddle documentation and found no discussion of beta and gamma.
Paddle network summary:
---------------------------------------------------------------------------
 Layer (type)        Input Shape          Output Shape         Param #
===========================================================================
   Conv2D-64      [[1, 31, 32, 32]]     [1, 64, 32, 32]        17,920
    ReLU-60       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-50      [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928
  BatchNorm-43    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256
    ReLU-46       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
 DHSIS_block-44   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-51      [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928
  BatchNorm-44    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256
    ReLU-47       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
 DHSIS_block-45   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-52      [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928
  BatchNorm-45    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256
    ReLU-48       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
 DHSIS_block-46   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-53      [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928
  BatchNorm-46    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256
    ReLU-49       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
 DHSIS_block-47   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-54      [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928
  BatchNorm-47    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256
    ReLU-50       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
 DHSIS_block-48   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-55      [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928
  BatchNorm-48    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256
    ReLU-51       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
 DHSIS_block-49   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-56      [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928
  BatchNorm-49    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256
    ReLU-52       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
 DHSIS_block-50   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-57      [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928
  BatchNorm-50    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256
    ReLU-53       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
 DHSIS_block-51   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-58      [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928
  BatchNorm-51    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256
    ReLU-54       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
 DHSIS_block-52   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-59      [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928
  BatchNorm-52    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256
    ReLU-55       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
 DHSIS_block-53   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-60      [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928
  BatchNorm-53    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256
    ReLU-56       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
 DHSIS_block-54   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-61      [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928
  BatchNorm-54    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256
    ReLU-57       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
 DHSIS_block-55   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-62      [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928
  BatchNorm-55    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256
    ReLU-58       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
 DHSIS_block-56   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-63      [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928
  BatchNorm-56    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256
    ReLU-59       [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
 DHSIS_block-57   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0
   Conv2D-65      [[1, 64, 32, 32]]     [1, 31, 32, 32]        17,887
===========================================================================
Total params: 556,383
Trainable params: 552,799
Non-trainable params: 3,584
---------------------------------------------------------------------------
Input size (MB): 0.12
Forward/backward pass size (MB): 29.24
Params size (MB): 2.12
Estimated Total Size (MB): 31.49
---------------------------------------------------------------------------
Keras (tf 1.5.0) network summary:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input (InputLayer)           (None, 32, 32, 31)        0
conv11 (Conv2D)              (None, 32, 32, 64)        17920
conv11_relu (Activation)     (None, 32, 32, 64)        0
conv12 (Conv2D)              (None, 32, 32, 64)        36928
conv12_bn (BatchNormalizatio (None, 32, 32, 64)        256
conv12_relu (Activation)     (None, 32, 32, 64)        0
conv13 (Conv2D)              (None, 32, 32, 64)        36928
conv13_bn (BatchNormalizatio (None, 32, 32, 64)        256
conv13_relu (Activation)     (None, 32, 32, 64)        0
conv14 (Conv2D)              (None, 32, 32, 64)        36928
conv14_bn (BatchNormalizatio (None, 32, 32, 64)        256
conv14relu (Activation)      (None, 32, 32, 64)        0
conv15 (Conv2D)              (None, 32, 32, 64)        36928
conv15_bn (BatchNormalizatio (None, 32, 32, 64)        256
conv15_relu (Activation)     (None, 32, 32, 64)        0
conv16 (Conv2D)              (None, 32, 32, 64)        36928
conv16_bn (BatchNormalizatio (None, 32, 32, 64)        256
conv16_relu (Activation)     (None, 32, 32, 64)        0
conv17 (Conv2D)              (None, 32, 32, 64)        36928
conv17_bn (BatchNormalizatio (None, 32, 32, 64)        256
conv17_relu (Activation)     (None, 32, 32, 64)        0
conv18 (Conv2D)              (None, 32, 32, 64)        36928
conv18_bn (BatchNormalizatio (None, 32, 32, 64)        256
conv18_relu (Activation)     (None, 32, 32, 64)        0
conv19 (Conv2D)              (None, 32, 32, 64)        36928
conv19_bn (BatchNormalizatio (None, 32, 32, 64)        256
conv19_relu (Activation)     (None, 32, 32, 64)        0
conv20 (Conv2D)              (None, 32, 32, 64)        36928
conv20_bn (BatchNormalizatio (None, 32, 32, 64)        256
conv20_relu (Activation)     (None, 32, 32, 64)        0
conv21 (Conv2D)              (None, 32, 32, 64)        36928
conv21_bn (BatchNormalizatio (None, 32, 32, 64)        256
conv21_relu (Activation)     (None, 32, 32, 64)        0
conv22 (Conv2D)              (None, 32, 32, 64)        36928
conv22_bn (BatchNormalizatio (None, 32, 32, 64)        256
conv22_relu (Activation)     (None, 32, 32, 64)        0
conv23 (Conv2D)              (None, 32, 32, 64)        36928
conv23_bn (BatchNormalizatio (None, 32, 32, 64)        256
conv23_relu (Activation)     (None, 32, 32, 64)        0
conv24 (Conv2D)              (None, 32, 32, 64)        36928
conv24_bn (BatchNormalizatio (None, 32, 32, 64)        256
conv24_relu (Activation)     (None, 32, 32, 64)        0
conv25 (Conv2D)              (None, 32, 32, 64)        36928
conv25_bn (BatchNormalizatio (None, 32, 32, 64)        256
conv25_relu (Activation)     (None, 32, 32, 64)        0
conv30 (Conv2D)              (None, 32, 32, 31)        17887
=================================================================
Total params: 556,383
Trainable params: 554,591
Non-trainable params: 1,792
_________________________________________________________________
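The gap between the two non-trainable counts can be reproduced by hand. Both networks contain 14 BatchNorm layers with 64 channels each; Keras counts only the two running statistics (moving mean and moving variance) as non-trainable, while the Paddle summary here appears to lump all four BatchNorm tensors (gamma, beta, running mean, running variance) into the non-trainable total. A quick Python check of the arithmetic, with the layer counts read off the summaries above:

```python
# Each BatchNorm layer over 64 channels holds 4 tensors of 64 values:
# gamma (scale), beta (shift), running mean, running variance.
num_bn_layers = 14
channels = 64
params_per_bn = 4 * channels          # 256, matching "Param #" in both tables

# Keras: only the two running statistics are counted as non-trainable.
keras_non_trainable = num_bn_layers * 2 * channels
# Paddle summary: all four BatchNorm tensors reported as non-trainable.
paddle_non_trainable = num_bn_layers * 4 * channels

print(params_per_bn)         # 256
print(keras_non_trainable)   # 1792
print(paddle_non_trainable)  # 3584
# The difference is exactly the gamma/beta parameters:
print(paddle_non_trainable - keras_non_trainable)  # 1792
```

So the two frameworks agree on every layer's parameter count; they only disagree on which bucket gamma and beta are reported in.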
XiangLiu (accepted answer) · #5 · replied 2022-06
With normal use of nn.BatchNorm2D(num_channels), beta and gamma are trained in net.train() mode, and mean and variance are updated as well. In net.eval() mode, beta and gamma are no longer updated, and mean and variance use the final running statistics obtained during training.
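To illustrate what the answer describes, here is a minimal NumPy sketch of batch-norm mechanics (an illustration of the idea, not Paddle's actual implementation): in training mode the batch statistics are used and the running statistics move toward them, while in eval mode the stored running statistics are used unchanged, and gamma/beta simply stop changing because no optimizer step is taken.

```python
import numpy as np

def batchnorm(x, gamma, beta, running_mean, running_var,
              training, momentum=0.9, eps=1e-5):
    """Per-channel batch norm over an (N, C) batch.

    Returns the normalized output and the (possibly updated)
    running statistics.
    """
    if training:
        mean = x.mean(axis=0)
        var = x.var(axis=0)
        # Training mode: running stats are an exponential moving average
        # of the per-batch statistics.
        running_mean = momentum * running_mean + (1 - momentum) * mean
        running_var = momentum * running_var + (1 - momentum) * var
    else:
        # Eval mode: use the statistics accumulated during training.
        mean, var = running_mean, running_var
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta, running_mean, running_var

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, size=(8, 4))
gamma, beta = np.ones(4), np.zeros(4)
rm, rv = np.zeros(4), np.ones(4)

# Training step: running stats move toward the batch statistics.
_, rm_new, rv_new = batchnorm(x, gamma, beta, rm, rv, training=True)
# Eval step: running stats are left untouched.
_, rm_eval, rv_eval = batchnorm(x, gamma, beta, rm_new, rv_new, training=False)
```

Because gamma and beta appear directly in the output (`gamma * x_hat + beta`), they receive gradients like any other weight whenever the network is in training mode; whether a summary tool labels them "trainable" is a separate, cosmetic question.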
Was your question resolved? I ran into the same problem.
Not resolved.
After testing, only the displayed parameter counts differ; the actual computation results are identical.
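The two printed summaries are consistent with this: the totals agree exactly, and the trainable counts differ by precisely the size of the gamma/beta tensors, so the discrepancy is a reporting difference, not a difference in the models. Checking with the numbers from the tables above:

```python
# Figures copied from the two model summaries.
total = 556_383
keras_trainable, keras_frozen = 554_591, 1_792
paddle_trainable, paddle_frozen = 552_799, 3_584

# Both frameworks account for every parameter.
assert keras_trainable + keras_frozen == total
assert paddle_trainable + paddle_frozen == total

# The trainable counts differ by 14 BN layers * (gamma + beta) = 14 * 2 * 64.
assert keras_trainable - paddle_trainable == 14 * 2 * 64  # 1792
```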
Learned something new.