pytorch Kfold数据集划分

今天想使用K折方法进行训练,发现 pytorch dataloader 中没有需要的一键操作的代码,我自己写了一个。

首先得到数据量,然后使用 sklearn.model_selection 的 KFold 方法划分数据索引,最后使用 torch.utils.data.dataset.Subset 方法得到划分后的子数据集。代码思路如下。

import torch
from sklearn.model_selection import KFold

data_induce = np.arange(0, data_loader_old.dataset.length)
kf = KFold(n_splits=5)

for train_index, val_index in kf.split(data_induce):
    train_subset = torch.utils.data.dataset.Subset(Dataset(params), train_index)
    val_subset = torch.utils.data.dataset.Subset(Dataset(params), val_index)
    data_loaders['train'] = torch.utils.data.DataLoader(train_subset, ...)
    data_loaders['val'] = data.pair_provider_subset(val_subset, ...)

参考:https://scikit-learn.org/stable/modules/cross_validation.html

https://stackoverflow.com/questions/60883696/k-fold-cross-validation-using-dataloaders-in-pytorch


版权声明:本文为qq_44761480原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。