多网络模型的一个坑,提示绕行~~
收藏
原来训练分类模型这样单个网络的模型时,一般更新梯度这么写的:
da_loss.backward() # 反向更新梯度
d_a_optimizer.step() # 更新模型权重
d_a_optimizer.clear_grad() # 清除梯度
只有一个网络时没问题,但在cyclegan这样的多网络模型中就会出错了,应该先清楚梯度再更新梯度,如下写法:
d_a_optimizer.clear_grad() # 清除梯度
da_loss.backward() # 反向更新梯度
d_a_optimizer.step() # 更新模型权重
0
收藏
请登录后评论
以前1.8版本似乎能容错
现在2.0以后版本越来越严谨,代码错不得~~