Paddle如何禁止部分层的参数更新?(类似requires_grad_(False))
收藏
各位,现在需要往主网络里加一个预训练的网络,然后预训练的网络不参与网络参数更新,训练主网络。
现在有一个预训练网络pretrain_net, torch可以通过pretrain_net.requires_grad_(False)实现。Paddle怎么实现?
现在模型已经封装好了,没法用定义层的 trainable来禁止。试过在 forward 中加入paddle.no_grad(),但是可训练参数还是包含着预训练网络。
谢谢!求指点!
0
收藏
请登录后评论
蹲个答案
直接把相应的层的requires_grad设置为false行吗?
呃呃,应该是相应的tensor
tensor.stop_gradient=True
可以试试,不保证正确……
设置之后,预训练网络之前的层就不更新了呀。我是把预训练网络嵌入到一个大网络中间了。
用法和pytorch应该是差不多的,你先看看它是怎么实现的
提前定义好不训练的类,在forward里传入不训练的类。不在主网里定义就可以了