多形态信息抽取 官方给的gru+crf代码跑不通
收藏
有错啊,一直提示什么 _init_的参数错误,有人跑通吗?求分享学习啊
0
收藏
请登录后评论
跑通了事件抽取,官方给的那一段参考代码里有挺多小坑,是不能直接拿来用的,建议仔细看一下改一改,(举个最明显的例子就是他__init__里用的是ernie但是下面forward用的是bert)。
求分享学习啊。lengths参数调了半天不行,加到forward里提示None跟int不匹配,设置int又提示size 8.244不匹配8.。这bug调的太痛苦了。
def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None, lengths=None):
sequence_output, _ = self.ernie(
input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask)
sequence_output = self.dropout(sequence_output)
bigru_output, _ = self.gru(sequence_output)
emission = self.fc(bigru_output)
_, prediction = self.viterbi_decoder(emission, lengths)
if labels is not None:
loss = self.crf_loss(emission, lengths, prediction, labels)
return loss, prediction, labels
else:
return lengths,prediction