首页 Paddle框架 帖子详情
Paddle如何禁止部分层的参数更新?(类似requires_grad_(False))
收藏
快速回复
Paddle框架 问答模型训练深度学习 4218 8
Paddle如何禁止部分层的参数更新?(类似requires_grad_(False))
收藏
快速回复
Paddle框架 问答模型训练深度学习 4218 8

各位,现在需要往主网络里加一个预训练的网络,然后预训练的网络不参与网络参数更新,训练主网络。

现在有一个预训练网络pretrain_net, torch可以通过pretrain_net.requires_grad_(False)实现。Paddle怎么实现?

现在模型已经封装好了,没法用定义层的 trainable来禁止。试过在 forward 中加入paddle.no_grad(),但是可训练参数还是包含着预训练网络。

谢谢!求指点!

0
收藏
回复
全部评论(8)
时间顺序
Lerbronjames
#2 回复于2022-08

蹲个答案

0
回复
玥亮
#3 回复于2022-09

直接把相应的层的requires_grad设置为false行吗?

0
回复
玥亮
#4 回复于2022-09
玥亮 #3
直接把相应的层的requires_grad设置为false行吗?

呃呃,应该是相应的tensor

0
回复
玥亮
#5 回复于2022-09
玥亮 #4
呃呃,应该是相应的tensor

tensor.stop_gradient=True

0
回复
玥亮
#6 回复于2022-09
玥亮 #5
tensor.stop_gradient=True

可以试试,不保证正确……

0
回复
S
Syuhen
#7 回复于2022-09
玥亮 #6
可以试试,不保证正确……

设置之后,预训练网络之前的层就不更新了呀。我是把预训练网络嵌入到一个大网络中间了。

0
回复
李长安
#8 回复于2022-09

用法和pytorch应该是差不多的,你先看看它是怎么实现的

0
回复
S
Syuhen
#9 回复于2022-09
蹲个答案

提前定义好不训练的类,在forward里传入不训练的类。不在主网里定义就可以了

0
回复
需求/bug反馈?一键提issue告诉我们
发现bug?如果您知道修复办法,欢迎提pr直接参与建设飞桨~
在@后输入用户全名并按空格结束,可艾特全站任一用户