利用yaml文件更新argparse.ArgumentParser参数

目录

1--argparse介绍

2--利用yaml文件更新argparse参数值

3--参考


1--argparse介绍

argparse是python内置的一个命令行解析模块,其常见用法如下:

# 导入第三方库
import argparse

# 创建解析对象
parser = argparse.ArgumentParser(description = 'Used to describe the purpose of the object') 

# 添加参数
parser.add_argument(
    '--Parameter_name', # 参数名
    type = str,         # 参数类型
    default = None,     # 默认参数值
    help = 'The purpose of the parameter') # 参数的用途

# 实例化对象
p = parser.parse_args(args = [])

# 输出解析对象的参数
print(p.Parameter_name)

2--利用yaml文件更新argparse参数值

当解析对象的参数值过多,以及需要频繁调试更新参数值(例如深度学习调参)时,使用yaml文件更新argparse参数值能有效提高工作效率,下面将以代码的形式展示利用yaml文件更新argparse参数值:

# 导入第三方库
import argparse
import yaml
import os


# 创建对象并初始化参数
def get_parser():
    
    # 创建解析对象
    parser = argparse.ArgumentParser(description = 'A method to update parser parameters using yaml files') 
    
    # 添加参数1
    parser.add_argument(
        '--num_worker', 
        type = int,
        default = 4,
        help = 'the number of worker for data loader')
    
    # 添加参数2
    parser.add_argument(
        '--lr',
        type = float,
        default = 0.01,
        help = 'the learning rate of SGD')
    
    # 添加参数3
    parser.add_argument(
    '--batchsize',
    type = int,
    default = 64,
    help = 'the batchsize of training stage')
    
    # 返回解析对象
    return parser 



# 创建一个yaml文件
def creat_yaml():
    
    # yaml文件存放的内容
    caps = {
        'num_worker': 16,
        'lr': 0.05,
    }

    # yaml文件存放的路径
    yamlpath = os.path.join('./', 'test.yaml')

    # caps的内容写入yaml文件
    with open(yamlpath, "w", encoding = "utf-8") as f:
        yaml.dump(caps, f)

        
# main()函数    
def main():
    
    # 创建一个yaml文件
    creat_yaml()
    
    # 创建解析对象
    parser = get_parser()
    
    # 实例化对象
    p = parser.parse_args(args = [])
    
    # 输出参数原始默认值
    print('The default value of num_worker is: ', p.num_worker)
    print('The default value of lr is: ', p.lr)
    print('The default value of batchsize is: ', p.batchsize)
    print('##############################')
    
    # 导入创建的yaml文件
    with open('test.yaml', 'r') as f:
        default_arg = yaml.load(f)
        
    # 创建解析对象
    parser = get_parser() 
    
    # 利用yaml文件更新默认值
    parser.set_defaults(**default_arg)
    
    # 实例化对象
    p = parser.parse_args(args = [])

    # 输出更新后的参数值
    print('The updated value of num_worker is: ', p.num_worker)
    print('The updated value of lr is: ', p.lr)
    print('The updated value of batchsize is: ', p.batchsize) # batchsize并没有用yaml文件更新哦


# 执行main函数
main()

代码执行结果:

The default value of num_worker is:  4
The default value of lr is:  0.01
The default value of batchsize is:  64
##############################
The updated value of num_worker is:  16
The updated value of lr is:  0.05
The updated value of batchsize is:  64

3--参考

参考链接1


版权声明:本文为weixin_43863869原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。