今天想使用K折方法进行训练,发现 pytorch dataloader 中没有需要的一键操作的代码,我自己写了一个。
首先得到数据量,然后使用 sklearn.model_selection 的 KFold 方法划分数据索引,最后使用 torch.utils.data.dataset.Subset 方法得到划分后的子数据集。代码思路如下。
import torch
from sklearn.model_selection import KFold
data_induce = np.arange(0, data_loader_old.dataset.length)
kf = KFold(n_splits=5)
for train_index, val_index in kf.split(data_induce):
train_subset = torch.utils.data.dataset.Subset(Dataset(params), train_index)
val_subset = torch.utils.data.dataset.Subset(Dataset(params), val_index)
data_loaders['train'] = torch.utils.data.DataLoader(train_subset, ...)
data_loaders['val'] = data.pair_provider_subset(val_subset, ...)
参考:https://scikit-learn.org/stable/modules/cross_validation.html
https://stackoverflow.com/questions/60883696/k-fold-cross-validation-using-dataloaders-in-pytorch
版权声明:本文为qq_44761480原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。