目录
1--argparse介绍
argparse是python内置的一个命令行解析模块,其常见用法如下:
# 导入第三方库
import argparse
# 创建解析对象
parser = argparse.ArgumentParser(description = 'Used to describe the purpose of the object')
# 添加参数
parser.add_argument(
'--Parameter_name', # 参数名
type = str, # 参数类型
default = None, # 默认参数值
help = 'The purpose of the parameter') # 参数的用途
# 实例化对象
p = parser.parse_args(args = [])
# 输出解析对象的参数
print(p.Parameter_name)
2--利用yaml文件更新argparse参数值
当解析对象的参数值过多,以及需要频繁调试更新参数值(例如深度学习调参)时,使用yaml文件更新argparse参数值能有效提高工作效率,下面将以代码的形式展示利用yaml文件更新argparse参数值:
# 导入第三方库
import argparse
import yaml
import os
# 创建对象并初始化参数
def get_parser():
# 创建解析对象
parser = argparse.ArgumentParser(description = 'A method to update parser parameters using yaml files')
# 添加参数1
parser.add_argument(
'--num_worker',
type = int,
default = 4,
help = 'the number of worker for data loader')
# 添加参数2
parser.add_argument(
'--lr',
type = float,
default = 0.01,
help = 'the learning rate of SGD')
# 添加参数3
parser.add_argument(
'--batchsize',
type = int,
default = 64,
help = 'the batchsize of training stage')
# 返回解析对象
return parser
# 创建一个yaml文件
def creat_yaml():
# yaml文件存放的内容
caps = {
'num_worker': 16,
'lr': 0.05,
}
# yaml文件存放的路径
yamlpath = os.path.join('./', 'test.yaml')
# caps的内容写入yaml文件
with open(yamlpath, "w", encoding = "utf-8") as f:
yaml.dump(caps, f)
# main()函数
def main():
# 创建一个yaml文件
creat_yaml()
# 创建解析对象
parser = get_parser()
# 实例化对象
p = parser.parse_args(args = [])
# 输出参数原始默认值
print('The default value of num_worker is: ', p.num_worker)
print('The default value of lr is: ', p.lr)
print('The default value of batchsize is: ', p.batchsize)
print('##############################')
# 导入创建的yaml文件
with open('test.yaml', 'r') as f:
default_arg = yaml.load(f)
# 创建解析对象
parser = get_parser()
# 利用yaml文件更新默认值
parser.set_defaults(**default_arg)
# 实例化对象
p = parser.parse_args(args = [])
# 输出更新后的参数值
print('The updated value of num_worker is: ', p.num_worker)
print('The updated value of lr is: ', p.lr)
print('The updated value of batchsize is: ', p.batchsize) # batchsize并没有用yaml文件更新哦
# 执行main函数
main()
代码执行结果:
The default value of num_worker is: 4
The default value of lr is: 0.01
The default value of batchsize is: 64
##############################
The updated value of num_worker is: 16
The updated value of lr is: 0.05
The updated value of batchsize is: 64
3--参考
版权声明:本文为weixin_43863869原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。