实现生成对抗网络的问题
收藏
我在尝试用动态图去实现生成对抗网络的时候,就如我下面这段主函数,在每一轮结束,我用测试函数带入生成函数G,得出来的图并没有发生变化。所以我就觉得可能是我的主函数写错了,导致生成器和判别器的模型参数并没有更新,有没有大佬帮我看看,我这样的写法是不是对的
if __name__=='__main__':
place=fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
epochs=5
test_data=fluid.dygraph.to_variable(test_z)
model_d=Discriminator('Discriminator')
model_g=Generator('Generator')
for epoch in range(epochs):
for i,real_image in enumerate(mnist_generator()):
#训练判别器D识别G生成的假图片
model_d.train()
model_g.eval()
data_g=np.array(next(z_generator))
data_g=fluid.dygraph.to_variable(data_g)
opt=fluid.optimizer.Adam(learning_rate=0.0002,parameter_list=model_d.parameters())
d=model_d(model_g(data_g))
avg_loss=loss(d,fluid.dygraph.to_variable(
np.zeros(d.shape).astype('float32'))) #计算损失
avg_loss.backward()
opt.minimize(avg_loss)
model_d.clear_gradients()
#训练D识别真图片
model_d.train()
data_d=np.array(real_image)
data_d=fluid.dygraph.to_variable(data_d)
opt=fluid.optimizer.Adam(learning_rate=0.0002,parameter_list=model_d.parameters())
d=model_d(data_d)
avg_loss=loss(d,fluid.dygraph.to_variable( #计算损失
np.ones(d.shape).astype('float32')))
avg_loss.backward()
opt.minimize(avg_loss)
model_d.clear_gradients()
#训练G生成假图片
model_g.train()
model_d.eval()
data_g=np.array(next(z_generator))
data_g=fluid.dygraph.to_variable(data_g)
g=model_g(data_g)
d=model_d(g)
avg_loss=loss(d,fluid.dygraph.to_variable(np.ones(d.shape).astype('float32')))
avg_loss.backward()
opt=fluid.optimizer.Adam(learning_rate=0.0002,parameter_list=model_g.parameters())
opt.minimize(avg_loss)
model_g.clear_gradients()
#打印每一轮的生成器的结果
model_g.eval()
img=model_g(test_data)
show_image_grid(img,epoch)
fluid.dygraph.save_dygraph(model_g.state_dict(),'model_g')
fluid.dygraph.save_dygraph(model_d.state_dict(),'model_d')
0
收藏
请登录后评论
我刚刚试了试打印了生成器的参数发现真的没有更新,请问一下是哪错了呢
建议去 https://ai.baidu.com/forum/topic/list/168 这里问一下哈. 那头懂飞桨的人多一些.
好的谢谢了,我刚刚又尝试了一下,发现 .eval 这个模式好像不能要,不过还在琢磨中
把你Discriminator的代码发上来
最好把Generator也发上来
这是我的这两个函数,我已经发现问题所在了,就是参数没更新好像是我不应该设置 model.eval() 模式,然后最开始我的Generator函数的最后一层的激活函数设置成relu了,刚刚看了一些博客说要设置成tanh。这样我的网络就跑出来了。
生成器和判别器的模型参数没有更新是正常的
逻辑稍微有些bug
对抗生成网络可以看看这个项目
一文搞懂生成对抗网络之经典GAN(动态图、VisualDL2.0)
https://aistudio.baidu.com/aistudio/projectdetail/551962
当时我也遇到相似的问题。你的GAN网络稳定么?有没有什么改进的思路
mark 一下,学习学习
感谢分享,好东西
弱弱请教一下,看看有没有老师帮忙解答一下困惑,关于cyclegan的,理论貌似看的比较清楚了,但是应该是还浮于表面。就理解不了一点,油画与照片的转化属于风格迁移,马和斑马的转化怎么是属于风格迁移呢?这两种转化感觉性质不一样啊。另外,这个训练和测试过程到底是什么样的呢?感觉看运行的代码像是运行一个命令,但是黑箱子里的过程还是没有看清楚。刚开始踏进来这个领域,希望获得大神指点。