求助,我在使用混合精度训练时,遇到了以下问题
收藏
AssertionError: Variable dtype not match, Variable [ dygraph_tmp_11 ] need tensor with dtype float32 but load tensor with dtype float16 可amp不是本身就要使用float16计算的嘛?
附上部分代码:
class D_m(nn.Layer):
def __init__(self):
super().__init__()
self.from_rgb=nn.utils.spectral_norm(conv(3,16,1,bias=True))
......
def forward(self,x):
x=self.from_rgb(x)
.....
return x
D=D_m()
......
with paddle.amp.auto_cast():
r=D(batch)
十分感谢!
0
收藏
请登录后评论
一般paddle都是要求双精度float32
我理解框架自动进行混合精度运算
参考文档:
https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_cn.html#zidonghunhejingduxunlian