数据集划分
进行机器学习项目时我们经常会将数据集划分为训练集、测试集(和验证集)。为了保证数据的客观性,我们经常通过随机的方式打乱数据集,本文将提供一种随机分配数据集的方式。
环境
python 3
代码部分
进入数据集目录(默认所有图片存在同一文件夹内,如果储存在不同子文件夹内,请用os.walk())
运行以下python代码。数据集会在此文件夹内分为三个文件夹:train, val, test
import os,shutil,random
val_rate=1
test_rate=2
train_rate=7
total_rate=val_rate+test_rate+train_rate
if not os.path.exists('train'):
os.mkdir('train')
if not os.path.exists('test'):
os.mkdir('test')
if not os.path.exists('val'):
os.mkdir('val')
files = os.listdir()
files.remove('train')
files.remove('test')
files.remove('val')
lenf = len(files)
lenval = (lenf*val_rate)//total_rate
lentest = (lenf*test_rate)//total_rate
lentrain = lenf - lenval - lentest
indv = random.sample(range(lenf),lenval)
for i in indv:
shutil.move(files[i],'val')
#如果你需要同时移动标注(annotations)文件,假设你的标注文件为csv文件,
#与图片存储在一起,要将其与图像一并移入val文件夹
#shutil.move(files[i][:-4]+'.csv','val')
files = os.listdir()
files.remove('train')
files.remove('test')
files.remove('val')
lenf = len(files)
indtest = random.sample(range(lenf),lentest)
for i in indtest:
shutil.move(files[i],'test')#如果不想改变原位置数据,可以用shutil.copyfile()
files = os.listdir()
files.remove('train')
files.remove('test')
files.remove('val')
for i in files:
shutil.move(i,'train')
这里还有一个可以每个文件夹的比例来划分数据集的代码,比如你做花卉分类,目标是区分月季、玫瑰、蔷薇、鸢尾。你的数据文件夹架构为
|- 花卉
|- 月季
|- 俯视摄影
|- 仰视摄影
|- 侧面摄影
|- 蔷薇
|- 俯视摄影
|- 仰视摄影
|- 侧面摄影
|- ······
如果你想让每个子文件夹的数据,在训练集和测试集的比例也保持为8:2,那么
import os,shutil,random
def split(root_path,train_test_rate=[8,2]):
output_path_train=os.path.join(root_path+"_splited","train")
output_path_test=os.path.join(root_path+"_splited","test")
train_rate=train_test_rate[0]
test_rate=train_test_rate[1]
for h,d,f in os.walk(root_path):
lenf=len(f)
f_train=random.sample(f,int(len(f)//((train_rate+test_rate)/train_rate)))
f_test=f[:]
output_dir_train=h.replace(root_path,output_path_train)
output_dir_test=h.replace(root_path,output_path_test)
os.makedirs(output_dir_train,exist_ok=True)
os.makedirs(output_dir_test,exist_ok=True)
a=0
lentrain=len(f_train)
for fts in f_train:
f_test.remove(fts)
ori_path=os.path.join(h,fts)
new_path=os.path.join(output_dir_train,fts)
shutil.copyfile(ori_path,new_path)
a+=1
print(h,": ",a,lentrain)
a=0
lentest=len(f_test)
for fts in f_test:
ori_path=os.path.join(h,fts)
new_path=os.path.join(output_dir_test,fts)
shutil.copyfile(ori_path,new_path)
a+=1
print(h,": ",a,lentest)
代码为后期整理,如有bug感谢回复反馈,不定期更新
更新日志2021/5/12 更新了文档说明
版权声明:本文为qq_43199053原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。