parl.layers.batch_norm
收藏
请问下强化学习框架 parl.layers.batch_norm这个函数怎么使用啊,能不能帮我在全连接层里面加一个bn
class Model(parl.Model):
def __init__(self, act_dim):
hid1_size = 128
hid2_size = 128
# 3层全连接网络
self.fc1 = layers.fc(size=hid1_size, act='relu')
self.fc2 = layers.fc(size=hid2_size, act='relu')
self.fc3 = layers.fc(size=act_dim, act=None)
def value(self, obs):
# 定义网络
# 输入state,输出所有action对应的Q,[Q(s,a1), Q(s,a2), Q(s,a3)...]
h1 = self.fc1(obs)
h2 = self.fc2(h1)
Q = self.fc3(h2)
return Q
0
收藏
请登录后评论
还可以带参数。具体可以看一下这个文件:
parl/core/fluid/layers/layer_wrappers.py
谢谢,昨天在别的地方找到答案了。要注意版本问题,parl.layers.batch_norm要配paddle1.6.1的版本
不客气,问题解决就好。