DQN训练飞翔的小鸟
收藏
源码地址: https://github.com/yeyupiaoling/ReinforcementLearning/tree/main/course2
请在真实电脑运行。
import os

import cv2
import numpy as np
import parl
from parl.utils import logger

import game.wrapped_flappy_bird as flappyBird
from agent import Agent
from model import Model
from replay_memory import ReplayMemory

LEARN_FREQ = 5                 # learn every N environment steps
MEMORY_SIZE = 20000            # replay memory capacity
MEMORY_WARMUP_SIZE = 200       # minimum transitions before learning starts
BATCH_SIZE = 64                # mini-batch size sampled from replay memory
LEARNING_RATE = 0.0005         # optimizer learning rate
GAMMA = 0.99                   # reward discount factor
E_GREED = 0.1                  # initial epsilon for epsilon-greedy exploration
E_GREED_DECREMENT = 1e-6       # epsilon decay applied during training
MAX_EPISODE = 10000            # total number of training episodes
# Downscaled observation shape (C, H, W) to cut model compute;
# the raw game frame is (288, 512).
RESIZE_SHAPE = (1, 224, 224)
SAVE_MODEL_PATH = "models/model.ckpt"  # checkpoint path


def preprocess(observation):
    """Convert a raw game frame into the network input.

    Resizes the frame, converts it to grayscale, binarizes it, and
    normalizes it to [0, 1] with a leading channel axis.

    Args:
        observation: raw BGR frame from the game environment.

    Returns:
        np.ndarray of shape RESIZE_SHAPE with float values in {0.0, 1.0}.
    """
    # cv2.resize takes (width, height); both are 224 here so order is safe.
    observation = cv2.resize(observation, (RESIZE_SHAPE[1], RESIZE_SHAPE[2]))
    observation = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
    # Binarize: every non-black pixel becomes 255.
    _, observation = cv2.threshold(observation, 1, 255, cv2.THRESH_BINARY)
    # Debug view of the processed frame (requires a display; see blog note
    # "run on a real computer").
    cv2.imshow("preprocess", observation)
    cv2.waitKey(1)
    observation = np.expand_dims(observation, axis=0)
    # After thresholding values are 0 or 255, so this yields 0.0 / 1.0.
    observation = observation / 255.0
    return observation


def run_train(agent, env, rpm):
    """Play one training episode, filling the replay memory and learning.

    Args:
        agent: the DQN agent (provides sample/learn).
        env: the Flappy Bird game environment.
        rpm: ReplayMemory storing (obs, action, reward, next_obs, done).

    Returns:
        Total (undiscounted) reward accumulated over the episode.
    """
    total_reward = 0
    obs = env.reset()
    obs = preprocess(obs)
    step = 0
    while True:
        step += 1
        # Epsilon-greedy action selection, then advance the game.
        action = agent.sample(obs, env)
        next_obs, reward, isOver = env.step(action, is_train=True)
        next_obs = preprocess(next_obs)
        # Store the transition; action is wrapped in a list to match the
        # replay memory's expected layout.
        rpm.append((obs, [action], reward, next_obs, isOver))
        # Learn only after warmup, and only every LEARN_FREQ steps.
        if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):
            (batch_obs, batch_action, batch_reward, batch_next_obs,
             batch_isOver) = rpm.sample(BATCH_SIZE)
            agent.learn(batch_obs, batch_action, batch_reward,
                        batch_next_obs, batch_isOver)
        total_reward += reward
        obs = next_obs
        if isOver:
            break
    return total_reward


def evaluate(agent, env):
    """Play one episode greedily (no exploration) and return its reward."""
    obs = env.reset()
    episode_reward = 0
    isOver = False
    while not isOver:
        obs = preprocess(obs)
        action = agent.predict(obs)
        obs, reward, isOver = env.step(action)
        episode_reward += reward
    return episode_reward


def main():
    """Train a DQN agent on Flappy Bird and periodically evaluate/save it."""
    # Initialize the game environment.
    env = flappyBird.GameState()
    # Observation shape and action dimension for the model/agent.
    obs_dim = RESIZE_SHAPE
    action_dim = env.action_dim
    # Replay memory holding past transitions.
    rpm = ReplayMemory(MEMORY_SIZE)
    # Build model -> algorithm -> agent stack.
    model = Model(act_dim=action_dim)
    algorithm = parl.algorithms.DQN(model, act_dim=action_dim,
                                    gamma=GAMMA, lr=LEARNING_RATE)
    agent = Agent(algorithm=algorithm,
                  obs_dim=obs_dim,
                  act_dim=action_dim,
                  e_greed=E_GREED,
                  e_greed_decrement=E_GREED_DECREMENT)
    # Warm up: fill the replay memory before any learning happens.
    print("开始预热...")
    while len(rpm) < MEMORY_WARMUP_SIZE:
        run_train(agent, env, rpm)
    # Main training loop: 50 training episodes per evaluation round.
    print("开始正式训练...")
    episode = 0
    while episode < MAX_EPISODE:
        for _ in range(50):
            train_reward = run_train(agent, env, rpm)
            episode += 1
            logger.info('Episode: {}, Reward: {:.2f}, e_greed: {:.2f}'.format(
                episode, train_reward, agent.e_greed))
        # Greedy evaluation episode.
        eval_reward = evaluate(agent, env)
        logger.info('Episode: {}, Evaluate reward:{:.2f}'.format(
            episode, eval_reward))
        # Persist the latest checkpoint (exist_ok avoids a race between
        # the exists() check and makedirs()).
        os.makedirs(os.path.dirname(SAVE_MODEL_PATH), exist_ok=True)
        agent.save(SAVE_MODEL_PATH)


if __name__ == '__main__':
    main()
0
收藏
请登录后评论
游戏环境在Github源码中