1-baselines/run.py解读

1-baselines/run.py解读

前言

就个人来说,比较喜欢从main函数入手,根据其运行的流程,一步一步找到每一步中涉及的参数以及函数的用意,最终摸清整个一个项目的流程和框架。
书接上文,上文书说到我们根据GitHub上的提示,可以使用下面的命令成功的运行了一个例子。

 python -m baselines.run --alg=deepq --env=PongNoFrameskip-v4 --num_timesteps=1e6

于是我们顺藤摸瓜,找到了baselines文件夹下的ren.py文件,这个文件就是这个项目的入口。

run.py

参数设置

我们根据上面的这个命令,先简单的分析一下他的参数构成,这里主要的参数有两个:

  • 参数alg=deepq表示的是采用的深度强化学习的算法,这里是指DQN。
  • 参数env=PongNoFrameskip-v4表示的是采用的测试的环境,这里是用的是PongNoFrameskip-v4,这是一个乒乓球小游戏,通过控制球拍上下移动接球,没接到球的一方就会丢失一分,先打到21分的一方就获胜了。

main函数

接下来,我们看一下run.py具体是怎么实现的(只讲解一些与程序相关的较大的语句
首先我们先找到main函数

def main(args):
    # configure logger, disable logging in child MPI processes (with rank > 0)

    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(unknown_args)

    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        configure_logger(args.log_path)
    else:
        rank = MPI.COMM_WORLD.Get_rank()
        configure_logger(args.log_path, format_strs=[])

    model, env = train(args, extra_args)

    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    if args.play:
        logger.log("Running trained model")
        obs = env.reset()

        state = model.initial_state if hasattr(model, 'initial_state') else None
        dones = np.zeros((1,))

        episode_rew = np.zeros(env.num_envs) if isinstance(env, VecEnv) else np.zeros(1)
        while True:
            if state is not None:
                actions, _, state, _ = model.step(obs,S=state, M=dones)
            else:
                actions, _, _, _ = model.step(obs)

            obs, rew, done, _ = env.step(actions)
            episode_rew += rew
            env.render()
            done_any = done.any() if isinstance(done, np.ndarray) else done
            if done_any:
                for i in np.nonzero(done)[0]:
                    print('episode_rew={}'.format(episode_rew[i]))
                    episode_rew[i] = 0

    env.close()

    return model

代码还是比较长的,我们一点点的分解来看一下

	arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(unknown_args)

这三句是与运行时传入参数的传递有关的,其中函数common_arg_parser() 是在common/cmd_util.py 文件中实现的,该函数也解释了各个参数的含义。

def common_arg_parser():
    """
    Create an argparse.ArgumentParser for run_mujoco.py.
    """
    parser = arg_parser()
    parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
    parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
    parser.add_argument('--seed', help='RNG seed', type=int, default=None)
    parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
    parser.add_argument('--num_timesteps', type=float, default=1e6),
    parser.add_argument('--network', help='network type (mlp, cnn, lstm, cnn_lstm, conv_only)', default=None)
    parser.add_argument('--gamestate', help='game state to load (so far only used in retro games)', default=None)
    parser.add_argument('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int)
    parser.add_argument('--reward_scale', help='Reward scale factor. Default: 1.0', default=1.0, type=float)
    parser.add_argument('--save_path', help='Path to save trained model to', default=None, type=str)
    parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
    parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
    parser.add_argument('--log_path', help='Directory to save learning curve data.', default=None, type=str)
    parser.add_argument('--play', default=False, action='store_true')
    return parser

函数arg_parser.parse_known_args() 是对传入参数的一个分类,在一开始的时候只接受设定好的参数(上面提到的common_arg_parser()中的设定参数)并保存到变量args中,而其他的未设定的参数被保存到unknown_args中,并传入parse_cmdline_kwargs()函数,以键值对的形式保存在**extra_args **中。

def parse_cmdline_kwargs(args):
    '''
    convert a list of '='-spaced command-line arguments to a dictionary, evaluating python objects when possible
    '''
    def parse(v):

        assert isinstance(v, str)
        try:
            return eval(v)
        except (NameError, SyntaxError):
            return v

    return {k: parse(v) for k,v in parse_unknown_args(args).items()}

接下来,main()函数中使用train()函数

model, env = train(args, extra_args)

该函数定义在run.py中

def train(args, extra_args):
    env_type, env_id = get_env_type(args)
    print('env_type: {}'.format(env_type))

    total_timesteps = int(args.num_timesteps)
    seed = args.seed

    learn = get_learn_function(args.alg)
    alg_kwargs = get_learn_function_defaults(args.alg, env_type)
    alg_kwargs.update(extra_args)

    env = build_env(args)
    if args.save_video_interval != 0:
        env = VecVideoRecorder(env, osp.join(logger.get_dir(), "videos"), record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length)

    if args.network:
        alg_kwargs['network'] = args.network
    else:
        if alg_kwargs.get('network') is None:
            alg_kwargs['network'] = get_default_network(env_type)

    print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))

    model = learn(
        env=env,
        seed=seed,
        total_timesteps=total_timesteps,
        **alg_kwargs
    )

    return model, env

前面是根据传入的参数对各种变量进行赋值

    env_type, env_id = get_env_type(args)
    print('env_type: {}'.format(env_type))

    total_timesteps = int(args.num_timesteps)
    seed = args.seed

    learn = get_learn_function(args.alg)
    alg_kwargs = get_learn_function_defaults(args.alg, env_type)
    alg_kwargs.update(extra_args)

之后通过build_env()函数

env = build_env(args)

未完待续。。。


版权声明:本文为qq_45880533原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。