『NLP经典项目集』10 - 快递单信息抽取 模型训练 参数错误
收藏
快递单信息抽取卡在了模型训练这一步,在线上运行是可以过的,但是在windows上运行时就报错了
model.fit(train_data=train_loader,
eval_data=dev_loader,
epochs=10,
save_dir='./results',
log_freq=1)
我查了一下,发现在网络构建这一步的forward函数中传参的类型发生了变化,线上运行时x和lens的dtype都是paddle.int64,但是在windows上就变成了paddle.int32,经过该函数处理后,lens的类型没变,返回参数pred的类型却是paddle.int64,这怎么处理啊,传参是在model里传过来的,我核对过版本号啊
def forward(self, x, lens):
print(x.dtype,lens.dtype)
embs = self.word_emb(x)
output, _ = self.gru(embs)
output = self.fc(output)
_, pred = self.decoder(output, lens)
print(output.dtype,pred.dtype,lens.dtype)
return output, lens, pred
0
收藏
请登录后评论
感觉这应该是一个 bug。这个,你到 Github 给 PaddleNLP 提 issue 吧~
我刚才试了一下,在 Windows 系统下创建的数据类型为整型的 Tensor 默认是 int32 的,而在 Ubuntu 系统下,则默认是 int64的。而浮点数则没有问题,都是 float32。这里的坑主要在那个解码器,就是 ViterbiDecoder 。看了一下 ViterbiDecoder 的源码,它里面的整数是显示用的 int64。所以,你要是在 Windows 系统下运行,你需要将 lens 的数据类型显示的转换为 int64:
刚才翻了一下 Github 上的 issue 发现有个跟你类似的问题:
在训练过程中数据类型int64_t没有 · Issue #31413 ·
https://github.com/PaddlePaddle/Paddle/issues/31413
项目如何拿到windows上运行的呢