首页 Paddle框架 帖子详情
PaddlePaddle网络参数如何固定初始化 已解决
收藏
快速回复
Paddle框架 其他学习资料 914 2
PaddlePaddle网络参数如何固定初始化 已解决
收藏
快速回复
Paddle框架 其他学习资料 914 2

搭建和训练神经网络时,一般会随机初始化网络中各层的参数,然后再在训练数据中学习更新。

但是,在如小样本学习等任务中,一些方法规定了网络参数不能随机初始化。那么如何使用特征数据去初始化网络中的参数呢?

答案是在初始化参数时候,将paddle.nn.initializer.Assign()赋值给default_initializer参数,具体如下:

# -*- coding: utf-8 -*-
# @Time    : 2021/12/5 12:58
# @Author  : He Ruizhi
# @File    : case.py
# @Software: PyCharm

import paddle
import numpy as np


class MyNet(paddle.nn.Layer):
    def __init__(self):
        super(MyNet, self).__init__()

        parameter_matrix = np.array([[1, 2],
                                     [3, 4],
                                     [5, 6]], dtype='float32')
        w = self.create_parameter(shape=list(parameter_matrix.shape),
                                  default_initializer=paddle.nn.initializer.Assign(parameter_matrix))
        b = self.create_parameter(shape=[parameter_matrix.shape[-1]], is_bias=True)
        self.add_parameter('w', w)
        self.add_parameter('b', b)

    def forward(self, x):
        return paddle.matmul(x, self.w) + self.b


if __name__ == "__main__":
    model = MyNet()
    paddle.summary(model, input_size=(None, 3))

    # 网络中参数的值
    for parameter in model.parameters():
        print(parameter.numpy())

通过打印输出可以看到,网络中的参数已经被指定初始化为特定的参数,而不是随机初始化的:

---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
    MyNet-1           [[1, 3]]              [1, 2]               8       
===========================================================================
Total params: 8
Trainable params: 8
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
---------------------------------------------------------------------------

[[1. 2.]
 [3. 4.]
 [5. 6.]]
[0. 0.]
DeepGeGe
已解决
2# 回复于2021-12
除了上述方法,也可以使用另一个低级别的API:paddle.static.create_parameter(shape, dtype, name=None, attr=None, is_bias=False, default_initializer=None)
展开
0
收藏
回复
全部评论(2)
时间顺序
DeepGeGe
#2 回复于2021-12

除了上述方法,也可以使用另一个低级别的API:paddle.static.create_parameter(shape, dtype, name=None, attr=None, is_bias=False, default_initializer=None)

0
回复
DeepGeGe
#3 回复于2021-12

那么如何使用特征数据去初始化网络中的参数呢?

这句话中的【特征数据】为【特定数据】,打错了。

0
回复
需求/bug反馈?一键提issue告诉我们
发现bug?如果您知道修复办法,欢迎提pr直接参与建设飞桨~
在@后输入用户全名并按空格结束,可艾特全站任一用户