前言
就个人来说,比较喜欢从main函数入手,根据其运行的流程,一步一步找到每一步中涉及的参数以及函数的用意,最终摸清整个一个项目的流程和框架。
书接上文,上文书说到我们根据GitHub上的提示,可以使用下面的命令成功的运行了一个例子。
python -m baselines.run --alg=deepq --env=PongNoFrameskip-v4 --num_timesteps=1e6
于是我们顺藤摸瓜,找到了baselines文件夹下的ren.py文件,这个文件就是这个项目的入口。
run.py
参数设置
我们根据上面的这个命令,先简单的分析一下他的参数构成,这里主要的参数有两个:
- 参数alg=deepq表示的是采用的深度强化学习的算法,这里是指DQN。
- 参数env=PongNoFrameskip-v4表示的是采用的测试的环境,这里是用的是PongNoFrameskip-v4,这是一个乒乓球小游戏,通过控制球拍上下移动接球,没接到球的一方就会丢失一分,先打到21分的一方就获胜了。
main函数
接下来,我们看一下run.py具体是怎么实现的(只讲解一些与程序相关的较大的语句)
首先我们先找到main函数
def main(args):
# configure logger, disable logging in child MPI processes (with rank > 0)
arg_parser = common_arg_parser()
args, unknown_args = arg_parser.parse_known_args(args)
extra_args = parse_cmdline_kwargs(unknown_args)
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
rank = 0
configure_logger(args.log_path)
else:
rank = MPI.COMM_WORLD.Get_rank()
configure_logger(args.log_path, format_strs=[])
model, env = train(args, extra_args)
if args.save_path is not None and rank == 0:
save_path = osp.expanduser(args.save_path)
model.save(save_path)
if args.play:
logger.log("Running trained model")
obs = env.reset()
state = model.initial_state if hasattr(model, 'initial_state') else None
dones = np.zeros((1,))
episode_rew = np.zeros(env.num_envs) if isinstance(env, VecEnv) else np.zeros(1)
while True:
if state is not None:
actions, _, state, _ = model.step(obs,S=state, M=dones)
else:
actions, _, _, _ = model.step(obs)
obs, rew, done, _ = env.step(actions)
episode_rew += rew
env.render()
done_any = done.any() if isinstance(done, np.ndarray) else done
if done_any:
for i in np.nonzero(done)[0]:
print('episode_rew={}'.format(episode_rew[i]))
episode_rew[i] = 0
env.close()
return model
代码还是比较长的,我们一点点的分解来看一下
arg_parser = common_arg_parser()
args, unknown_args = arg_parser.parse_known_args(args)
extra_args = parse_cmdline_kwargs(unknown_args)
这三句是与运行时传入参数的传递有关的,其中函数common_arg_parser() 是在common/cmd_util.py 文件中实现的,该函数也解释了各个参数的含义。
def common_arg_parser():
"""
Create an argparse.ArgumentParser for run_mujoco.py.
"""
parser = arg_parser()
parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
parser.add_argument('--seed', help='RNG seed', type=int, default=None)
parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
parser.add_argument('--num_timesteps', type=float, default=1e6),
parser.add_argument('--network', help='network type (mlp, cnn, lstm, cnn_lstm, conv_only)', default=None)
parser.add_argument('--gamestate', help='game state to load (so far only used in retro games)', default=None)
parser.add_argument('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int)
parser.add_argument('--reward_scale', help='Reward scale factor. Default: 1.0', default=1.0, type=float)
parser.add_argument('--save_path', help='Path to save trained model to', default=None, type=str)
parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
parser.add_argument('--log_path', help='Directory to save learning curve data.', default=None, type=str)
parser.add_argument('--play', default=False, action='store_true')
return parser
函数arg_parser.parse_known_args() 是对传入参数的一个分类,在一开始的时候只接受设定好的参数(上面提到的common_arg_parser()中的设定参数)并保存到变量args中,而其他的未设定的参数被保存到unknown_args中,并传入parse_cmdline_kwargs()函数,以键值对的形式保存在**extra_args **中。
def parse_cmdline_kwargs(args):
'''
convert a list of '='-spaced command-line arguments to a dictionary, evaluating python objects when possible
'''
def parse(v):
assert isinstance(v, str)
try:
return eval(v)
except (NameError, SyntaxError):
return v
return {k: parse(v) for k,v in parse_unknown_args(args).items()}
接下来,main()函数中使用train()函数
model, env = train(args, extra_args)
该函数定义在run.py中
def train(args, extra_args):
env_type, env_id = get_env_type(args)
print('env_type: {}'.format(env_type))
total_timesteps = int(args.num_timesteps)
seed = args.seed
learn = get_learn_function(args.alg)
alg_kwargs = get_learn_function_defaults(args.alg, env_type)
alg_kwargs.update(extra_args)
env = build_env(args)
if args.save_video_interval != 0:
env = VecVideoRecorder(env, osp.join(logger.get_dir(), "videos"), record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length)
if args.network:
alg_kwargs['network'] = args.network
else:
if alg_kwargs.get('network') is None:
alg_kwargs['network'] = get_default_network(env_type)
print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))
model = learn(
env=env,
seed=seed,
total_timesteps=total_timesteps,
**alg_kwargs
)
return model, env
前面是根据传入的参数对各种变量进行赋值
env_type, env_id = get_env_type(args)
print('env_type: {}'.format(env_type))
total_timesteps = int(args.num_timesteps)
seed = args.seed
learn = get_learn_function(args.alg)
alg_kwargs = get_learn_function_defaults(args.alg, env_type)
alg_kwargs.update(extra_args)
之后通过build_env()函数
env = build_env(args)
未完待续。。。
版权声明:本文为qq_45880533原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。