我在下载 ernie_gen 模型（运行下面的代码）时，程序报错：KeyError: '[MASK]'。
以下是代码：
import os
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
import paddle.fluid.dygraph as D
from ernie.modeling_ernie import ErnieModelForPretraining, ErnieModel
from ernie.tokenizing_ernie import ErnieTokenizer
# Name of the pretrained ERNIE-GEN checkpoint passed to from_pretrained().
model_dir = 'ernie-gen-base-en'

# Switch paddle into dygraph (imperative) mode for the rest of the script.
# NOTE(review): the guard is entered manually and never exited; that is fine
# for a one-off script/notebook, but library code should use `with D.guard():`.
D.guard().__enter__()

tokenizer = ErnieTokenizer.from_pretrained(model_dir)

# Id of the [MASK] token. ErnieCloze.forward looks up this module-level name
# to find the masked positions, so it must exist before the model is run.
if '[MASK]' not in tokenizer.vocab:
    # NOTE(review): presumably the vocab file for this checkpoint is
    # incomplete or lacks a [MASK] entry -- TODO confirm against the cache.
    raise KeyError(
        "token '[MASK]' not found in the vocabulary loaded for %r" % model_dir)
mask_id = tokenizer.vocab['[MASK]']

# Reverse vocabulary (token id -> token string) used to decode predictions.
rev_dict = {v: k for k, v in tokenizer.vocab.items()}
# Decode these special tokens as empty strings so they vanish from output.
rev_dict[tokenizer.pad_id] = ''  # replace [PAD]
rev_dict[tokenizer.sep_id] = ''  # replace [SEP]
rev_dict[tokenizer.unk_id] = ''  # replace [UNK]
class ErnieCloze(ErnieModelForPretraining):
    """Cloze (masked-token) predictor built on ErnieModelForPretraining.

    forward() runs the base ErnieModel encoder, keeps only the hidden states
    at positions whose id equals the module-level ``mask_id``, and scores them
    against every row of the word-embedding matrix (tied output projection).
    """

    def __init__(self, *args, **kwargs):
        super(ErnieCloze, self).__init__(*args, **kwargs)
        # forward() below never uses the pooler heads, so drop them.
        del self.pooler_heads

    def forward(self, src_ids, *args, **kwargs):
        """Return logits for each [MASK] position in ``src_ids``.

        Extra positional/keyword arguments are forwarded unchanged to
        ``ErnieModel.forward``. Reads the module-level ``mask_id``, which must
        be defined before this is called.
        """
        # Bypass ErnieModelForPretraining.forward; only the encoder output is
        # needed here (the pooled output is discarded).
        _, hidden = ErnieModel.forward(self, src_ids, *args, **kwargs)
        mask_positions = L.where(src_ids == mask_id)
        hidden_at_mask = L.gather_nd(hidden, mask_positions)
        transformed = self.mlm_ln(self.mlm(hidden_at_mask))
        logits = L.matmul(transformed, self.word_emb.weight, transpose_y=True)
        return logits + self.mlm_bias
def _rev_lookup_one(i):
    """Map a single token id to its string via the module-level ``rev_dict``.

    Raises KeyError if ``i`` has no entry in ``rev_dict``.
    """
    return rev_dict[i]


# Element-wise id -> token lookup over arrays of any shape; the result has the
# same shape as the input. ``otypes=[str]`` is needed so that a size-0 input
# (e.g. no [MASK] positions found) yields an empty array instead of raising
# ValueError. It also stops np.vectorize from calling the lookup an extra time
# on the first element just to infer the output dtype.
rev_lookup = np.vectorize(_rev_lookup_one, otypes=[str])