实现生成对抗网络的问题
收藏
我在尝试用动态图去实现生成对抗网络的时候,就如我下面这段主函数,在每一轮结束,我用测试函数带入生成函数G,得出来的图并没有发生变化。所以我就觉得可能是我的主函数写错了,导致生成器和判别器的模型参数并没有更新,有没有大佬帮我看看,我这样的写法是不是对的
if __name__=='__main__': place=fluid.CUDAPlace(0) with fluid.dygraph.guard(place): epochs=5 test_data=fluid.dygraph.to_variable(test_z) model_d=Discriminator('Discriminator') model_g=Generator('Generator') for epoch in range(epochs): for i,real_image in enumerate(mnist_generator()): #训练判别器D识别G生成的假图片 model_d.train() model_g.eval() data_g=np.array(next(z_generator)) data_g=fluid.dygraph.to_variable(data_g) opt=fluid.optimizer.Adam(learning_rate=0.0002,parameter_list=model_d.parameters()) d=model_d(model_g(data_g)) avg_loss=loss(d,fluid.dygraph.to_variable( np.zeros(d.shape).astype('float32'))) #计算损失 avg_loss.backward() opt.minimize(avg_loss) model_d.clear_gradients() #训练D识别真图片 model_d.train() data_d=np.array(real_image) data_d=fluid.dygraph.to_variable(data_d) opt=fluid.optimizer.Adam(learning_rate=0.0002,parameter_list=model_d.parameters()) d=model_d(data_d) avg_loss=loss(d,fluid.dygraph.to_variable( #计算损失 np.ones(d.shape).astype('float32'))) avg_loss.backward() opt.minimize(avg_loss) model_d.clear_gradients() #训练G生成假图片 model_g.train() model_d.eval() data_g=np.array(next(z_generator)) data_g=fluid.dygraph.to_variable(data_g) g=model_g(data_g) d=model_d(g) avg_loss=loss(d,fluid.dygraph.to_variable(np.ones(d.shape).astype('float32'))) avg_loss.backward() opt=fluid.optimizer.Adam(learning_rate=0.0002,parameter_list=model_g.parameters()) opt.minimize(avg_loss) model_g.clear_gradients() #打印每一轮的生成器的结果 model_g.eval() img=model_g(test_data) show_image_grid(img,epoch) fluid.dygraph.save_dygraph(model_g.state_dict(),'model_g') fluid.dygraph.save_dygraph(model_d.state_dict(),'model_d')
0
收藏
请登录后评论
我刚刚试了试打印了生成器的参数发现真的没有更新,请问一下是哪错了呢
建议去 https://ai.baidu.com/forum/topic/list/168 这里问一下哈. 那头懂飞桨的人多一些.
好的谢谢了,我刚刚又尝试了一下,发现 .eval 这个模式好像不能要,不过还在琢磨中
把你Discriminator的代码发上来
最好把Generator也发上来
这是我的这两个函数,我已经发现问题所在了,就是参数没更新好像是我不应该设置 model.eval() 模式,然后最开始我的Generator函数的最后一层的激活函数设置成relu了,刚刚看了一些博客说要设置成tanh。这样我的网络就跑出来了。
生成器和判别器的模型参数没有更新是正常的
逻辑稍微有些bug
对抗生成网络可以看看这个项目
一文搞懂生成对抗网络之经典GAN(动态图、VisualDL2.0)
https://aistudio.baidu.com/aistudio/projectdetail/551962
当时我也遇到相似的问题。你的GAN网络稳定么?有没有什么改进的思路
mark 一下,学习学习
感谢分享,好东西
弱弱请教一下,看看有没有老师帮忙解答一下困惑,关于cyclegan的,理论貌似看的比较清楚了,但是应该是还浮于表面。就理解不了一点,油画与照片的转化属于风格迁移,马和斑马的转化怎么是属于风格迁移呢?这两种转化感觉性质不一样啊。另外,这个训练和测试过程到底是什么样的呢?感觉看运行的代码像是运行一个命令,但是黑箱子里的过程还是没有看清楚。刚开始踏进来这个领域,希望获得大神指点。