Are BatchNorm's beta and gamma parameters not trainable when building a network in Paddle? [Solved]
Paddle Framework forum · Q&A / learning material · 287 · 5

I built the same network in both Keras and Paddle; both report Total params: 556,383.

However, Keras reports Non-trainable params: 1,792,

while the Paddle version reports Non-trainable params: 3,584.

Inspecting the network structure, I found that the batch-normalization layers' parameters are treated as non-trainable. How can I fix this?

I checked Paddle's documentation and found no description of beta and gamma.

The network summary from Paddle:

---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Conv2D-64     [[1, 31, 32, 32]]     [1, 64, 32, 32]        17,920     
    ReLU-60      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-50     [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928     
 BatchNorm-43    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256      
    ReLU-46      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
DHSIS_block-44   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-51     [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928     
 BatchNorm-44    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256      
    ReLU-47      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
DHSIS_block-45   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-52     [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928     
 BatchNorm-45    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256      
    ReLU-48      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
DHSIS_block-46   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-53     [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928     
 BatchNorm-46    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256      
    ReLU-49      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
DHSIS_block-47   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-54     [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928     
 BatchNorm-47    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256      
    ReLU-50      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
DHSIS_block-48   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-55     [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928     
 BatchNorm-48    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256      
    ReLU-51      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
DHSIS_block-49   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-56     [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928     
 BatchNorm-49    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256      
    ReLU-52      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
DHSIS_block-50   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-57     [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928     
 BatchNorm-50    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256      
    ReLU-53      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
DHSIS_block-51   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-58     [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928     
 BatchNorm-51    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256      
    ReLU-54      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
DHSIS_block-52   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-59     [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928     
 BatchNorm-52    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256      
    ReLU-55      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
DHSIS_block-53   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-60     [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928     
 BatchNorm-53    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256      
    ReLU-56      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
DHSIS_block-54   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-61     [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928     
 BatchNorm-54    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256      
    ReLU-57      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
DHSIS_block-55   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-62     [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928     
 BatchNorm-55    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256      
    ReLU-58      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
DHSIS_block-56   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-63     [[1, 64, 32, 32]]     [1, 64, 32, 32]        36,928     
 BatchNorm-56    [[1, 64, 32, 32]]     [1, 64, 32, 32]          256      
    ReLU-59      [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
DHSIS_block-57   [[1, 64, 32, 32]]     [1, 64, 32, 32]           0       
   Conv2D-65     [[1, 64, 32, 32]]     [1, 31, 32, 32]        17,887     
===========================================================================
Total params: 556,383
Trainable params: 552,799
Non-trainable params: 3,584
---------------------------------------------------------------------------
Input size (MB): 0.12
Forward/backward pass size (MB): 29.24
Params size (MB): 2.12
Estimated Total Size (MB): 31.49
---------------------------------------------------------------------------

The network summary from Keras (tf 1.5.0):

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input (InputLayer)           (None, 32, 32, 31)        0         
_________________________________________________________________
conv11 (Conv2D)              (None, 32, 32, 64)        17920     
_________________________________________________________________
conv11_relu (Activation)     (None, 32, 32, 64)        0         
_________________________________________________________________
conv12 (Conv2D)              (None, 32, 32, 64)        36928     
_________________________________________________________________
conv12_bn (BatchNormalizatio (None, 32, 32, 64)        256       
_________________________________________________________________
conv12_relu (Activation)     (None, 32, 32, 64)        0         
_________________________________________________________________
conv13 (Conv2D)              (None, 32, 32, 64)        36928     
_________________________________________________________________
conv13_bn (BatchNormalizatio (None, 32, 32, 64)        256       
_________________________________________________________________
conv13_relu (Activation)     (None, 32, 32, 64)        0         
_________________________________________________________________
conv14 (Conv2D)              (None, 32, 32, 64)        36928     
_________________________________________________________________
conv14_bn (BatchNormalizatio (None, 32, 32, 64)        256       
_________________________________________________________________
conv14relu (Activation)      (None, 32, 32, 64)        0         
_________________________________________________________________
conv15 (Conv2D)              (None, 32, 32, 64)        36928     
_________________________________________________________________
conv15_bn (BatchNormalizatio (None, 32, 32, 64)        256       
_________________________________________________________________
conv15_relu (Activation)     (None, 32, 32, 64)        0         
_________________________________________________________________
conv16 (Conv2D)              (None, 32, 32, 64)        36928     
_________________________________________________________________
conv16_bn (BatchNormalizatio (None, 32, 32, 64)        256       
_________________________________________________________________
conv16_relu (Activation)     (None, 32, 32, 64)        0         
_________________________________________________________________
conv17 (Conv2D)              (None, 32, 32, 64)        36928     
_________________________________________________________________
conv17_bn (BatchNormalizatio (None, 32, 32, 64)        256       
_________________________________________________________________
conv17_relu (Activation)     (None, 32, 32, 64)        0         
_________________________________________________________________
conv18 (Conv2D)              (None, 32, 32, 64)        36928     
_________________________________________________________________
conv18_bn (BatchNormalizatio (None, 32, 32, 64)        256       
_________________________________________________________________
conv18_relu (Activation)     (None, 32, 32, 64)        0         
_________________________________________________________________
conv19 (Conv2D)              (None, 32, 32, 64)        36928     
_________________________________________________________________
conv19_bn (BatchNormalizatio (None, 32, 32, 64)        256       
_________________________________________________________________
conv19_relu (Activation)     (None, 32, 32, 64)        0         
_________________________________________________________________
conv20 (Conv2D)              (None, 32, 32, 64)        36928     
_________________________________________________________________
conv20_bn (BatchNormalizatio (None, 32, 32, 64)        256       
_________________________________________________________________
conv20_relu (Activation)     (None, 32, 32, 64)        0         
_________________________________________________________________
conv21 (Conv2D)              (None, 32, 32, 64)        36928     
_________________________________________________________________
conv21_bn (BatchNormalizatio (None, 32, 32, 64)        256       
_________________________________________________________________
conv21_relu (Activation)     (None, 32, 32, 64)        0         
_________________________________________________________________
conv22 (Conv2D)              (None, 32, 32, 64)        36928     
_________________________________________________________________
conv22_bn (BatchNormalizatio (None, 32, 32, 64)        256       
_________________________________________________________________
conv22_relu (Activation)     (None, 32, 32, 64)        0         
_________________________________________________________________
conv23 (Conv2D)              (None, 32, 32, 64)        36928     
_________________________________________________________________
conv23_bn (BatchNormalizatio (None, 32, 32, 64)        256       
_________________________________________________________________
conv23_relu (Activation)     (None, 32, 32, 64)        0         
_________________________________________________________________
conv24 (Conv2D)              (None, 32, 32, 64)        36928     
_________________________________________________________________
conv24_bn (BatchNormalizatio (None, 32, 32, 64)        256       
_________________________________________________________________
conv24_relu (Activation)     (None, 32, 32, 64)        0         
_________________________________________________________________
conv25 (Conv2D)              (None, 32, 32, 64)        36928     
_________________________________________________________________
conv25_bn (BatchNormalizatio (None, 32, 32, 64)        256       
_________________________________________________________________
conv25_relu (Activation)     (None, 32, 32, 64)        0         
_________________________________________________________________
conv30 (Conv2D)              (None, 32, 32, 31)        17887     
=================================================================
Total params: 556,383
Trainable params: 554,591
Non-trainable params: 1,792
_________________________________________________________________
XiangLiu
[Accepted answer]
#5 · replied 2022-06
With the standard nn.BatchNorm2D(num_channels), beta and gamma are trained in net.train() mode, and the running mean and variance are updated as well. In net.eval() mode, beta and gamma are no longer updated, and the final running mean and variance obtained during training are used.
All comments (5)
sorted by time
XiangLiu
#2 · replied 2022-06

Did you ever resolve this? I ran into the same issue.
Syuhen
#3 · replied 2022-06
> Did you ever resolve this? I ran into the same issue.

No, not resolved.
XiangLiu
#4 · replied 2022-06
> Syuhen #3: No, not resolved.

After testing: only the displayed parameter counts differ; the actual computation results are identical.
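The gap in the displayed counts can be checked by hand against the two summaries. Each of the 14 BatchNorm layers over 64 channels holds 4 × 64 = 256 parameters (gamma, beta, moving mean, moving variance). The numbers are consistent with Paddle's summary listing all four per layer as non-trainable, while Keras lists only the two moving statistics:

```python
channels = 64    # channels per BatchNorm layer in both summaries
bn_layers = 14   # BatchNorm-43 … BatchNorm-56 / conv12_bn … conv25_bn

per_layer = 4 * channels    # gamma + beta + moving mean + moving variance
moving_only = 2 * channels  # moving mean + moving variance

print(bn_layers * per_layer)    # 3584 -- Paddle's Non-trainable count
print(bn_layers * moving_only)  # 1792 -- Keras's Non-trainable count
print(bn_layers * (per_layer - moving_only))  # 1792 gamma/beta params
```

The 1,792-parameter difference is exactly the gamma and beta parameters, which are in fact trained.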
Tedoliy
#7 · replied 2022-09

Learned something new.