parl.layers.batch_norm
收藏
请问下强化学习框架 parl.layers.batch_norm这个函数怎么使用啊,能不能帮我在全连接层里面加一个bn
class Model(parl.Model): def __init__(self, act_dim): hid1_size = 128 hid2_size = 128 # 3层全连接网络 self.fc1 = layers.fc(size=hid1_size, act='relu') self.fc2 = layers.fc(size=hid2_size, act='relu') self.fc3 = layers.fc(size=act_dim, act=None) def value(self, obs): # 定义网络 # 输入state,输出所有action对应的Q,[Q(s,a1), Q(s,a2), Q(s,a3)...] h1 = self.fc1(obs) h2 = self.fc2(h1) Q = self.fc3(h2) return Q
0
收藏
请登录后评论
还可以带参数。具体可以看一下这个文件:
parl/core/fluid/layers/layer_wrappers.py
谢谢,昨天在别的地方找到答案了。要注意版本问题,parl.layers.batch_norm要配paddle1.6.1的版本
不客气,问题解决就好。