self.rnn = nn.SimpleRNN(256, 256, num_layers=2, direction='forward',dropout=0.5) 前面这个forward 参数,不是双向的,但是后面的代码注释是双向的意思,尝试删除下面这句代码就会出错,
hidden = paddle.concat((hidden[-2,:,:], hidden[-1,:,:]), axis = 1
# 定义RNN网络
class MyRNN(paddle.nn.Layer):
def __init__(self):
super(MyRNN, self).__init__()
self.embedding = nn.Embedding(vocab_size, 256)
self.rnn = nn.SimpleRNN(256, 256, num_layers=2, direction='forward',dropout=0.5)
self.linear = nn.Linear(in_features=256*2, out_features=2)
self.dropout = nn.Dropout(0.5)
def forward(self, inputs):
emb = self.dropout(self.embedding(inputs))
#output形状大小为[batch_size,seq_len,num_directions * hidden_size]
#hidden形状大小为[num_layers * num_directions, batch_size, hidden_size]
#把前向的hidden与后向的hidden合并在一起
output, hidden = self.rnn(emb)
hidden = paddle.concat((hidden[-2,:,:], hidden[-1,:,:]), axis = 1)
#hidden形状大小为[batch_size, hidden_size * num_directions]
hidden = self.dropout(hidden)
return self.linear(hidden)
direction (str,可选) - 网络迭代方向,可设置为forward或bidirect(或bidirectional)。默认为forward。
可以看看API文档:https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/SimpleRNN_cn.html
多谢回复,我知道可选参数,我想问的是:为什以删掉那句代码就报错