将数据集划分为训练集与测试集

数据集划分

进行机器学习项目时我们经常会将数据集划分为训练集、测试集(和验证集)。为了保证数据的客观性,我们经常通过随机的方式打乱数据集,本文将提供一种随机分配数据集的方式。

环境

python 3

代码部分

进入数据集目录(默认所有图片存在同一文件夹内,如果储存在不同子文件夹内,请用os.walk())
运行以下python代码。数据集会在此文件夹内分为三个文件夹:train, val, test

import os,shutil,random

val_rate=1
test_rate=2
train_rate=7
total_rate=val_rate+test_rate+train_rate

if not os.path.exists('train'):
    os.mkdir('train')
if not os.path.exists('test'):
    os.mkdir('test')
if not os.path.exists('val'):
    os.mkdir('val')
    
files = os.listdir()
files.remove('train')
files.remove('test')
files.remove('val')
lenf = len(files)
lenval = (lenf*val_rate)//total_rate
lentest = (lenf*test_rate)//total_rate
lentrain = lenf - lenval - lentest

indv = random.sample(range(lenf),lenval)
for i in indv:
    shutil.move(files[i],'val')
    
    #如果你需要同时移动标注(annotations)文件,假设你的标注文件为csv文件,
    #与图片存储在一起,要将其与图像一并移入val文件夹
    #shutil.move(files[i][:-4]+'.csv','val') 

files = os.listdir()
files.remove('train')
files.remove('test')
files.remove('val')
lenf = len(files)
indtest = random.sample(range(lenf),lentest)
for i in indtest:
    shutil.move(files[i],'test')#如果不想改变原位置数据,可以用shutil.copyfile()

files = os.listdir()
files.remove('train')
files.remove('test')
files.remove('val')
for i in files:
    shutil.move(i,'train')

这里还有一个可以每个文件夹的比例来划分数据集的代码,比如你做花卉分类,目标是区分月季、玫瑰、蔷薇、鸢尾。你的数据文件夹架构为

|- 花卉
	|- 月季
		|- 俯视摄影
		|- 仰视摄影
		|- 侧面摄影
	|- 蔷薇
		|- 俯视摄影
		|- 仰视摄影
		|- 侧面摄影
	|- ······

如果你想让每个子文件夹的数据,在训练集和测试集的比例也保持为8:2,那么

import os,shutil,random

def split(root_path,train_test_rate=[8,2]):        
    output_path_train=os.path.join(root_path+"_splited","train")    
    output_path_test=os.path.join(root_path+"_splited","test")    
    train_rate=train_test_rate[0]    
    test_rate=train_test_rate[1]    
    for h,d,f in os.walk(root_path):        
        lenf=len(f)
        f_train=random.sample(f,int(len(f)//((train_rate+test_rate)/train_rate)))        
        f_test=f[:]        
        output_dir_train=h.replace(root_path,output_path_train)        
        output_dir_test=h.replace(root_path,output_path_test)        
        os.makedirs(output_dir_train,exist_ok=True)        
        os.makedirs(output_dir_test,exist_ok=True)        
        a=0        
        lentrain=len(f_train)        
        for fts in f_train:            
            f_test.remove(fts)            
            ori_path=os.path.join(h,fts)            
            new_path=os.path.join(output_dir_train,fts)            
            shutil.copyfile(ori_path,new_path)            
            a+=1            
            print(h,": ",a,lentrain)        
        a=0        
        lentest=len(f_test)        
        for fts in f_test:            
            ori_path=os.path.join(h,fts)            
            new_path=os.path.join(output_dir_test,fts)            
            shutil.copyfile(ori_path,new_path)            
            a+=1            
            print(h,": ",a,lentest)

代码为后期整理,如有bug感谢回复反馈,不定期更新

更新日志2021/5/12 更新了文档说明

版权声明:本文为qq_43199053原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。