K-fold cross validation using DataLoaders in PyTorch

I'm new to this, so please bear with me.

I have split my training dataset into 80% training data and 20% validation data and created DataLoaders as shown below. However, I don't want to limit my model's training to a single fixed split, so I thought of splitting my data into K (maybe 5) folds and performing cross-validation. But I don't know how to combine the folds into my DataLoaders after splitting the dataset.

train_size = int(0.8 * len(full_dataset))
validation_size = len(full_dataset) - train_size
train_dataset, validation_dataset = random_split(full_dataset, [train_size, validation_size])

full_loader = DataLoader(full_dataset, batch_size=4, sampler=sampler_(full_dataset), pin_memory=True)
train_loader = DataLoader(train_dataset, batch_size=4, sampler=sampler_(train_dataset))
val_loader = DataLoader(validation_dataset, batch_size=1, sampler=sampler_(validation_dataset))

Originally written by Suraj, licensed under CC BY-SA 4.0.

1 Answer

I just wrote a cross validation function that works with dataloaders and datasets. Here is my code, hope this helps.

# define a cross validation function
import torch
import pandas as pd

def crossvalid(model=None, criterion=None, optimizer=None, dataset=None, k_fold=5):

    train_score = pd.Series(dtype=float)
    val_score = pd.Series(dtype=float)

    total_size = len(dataset)
    fraction = 1 / k_fold
    seg = int(total_size * fraction)
    # tr: train, val: validation; l: left index, r: right index
    # fold i uses [vall, valr) for validation; the training set is everything else,
    # i.e. [trll, trlr) to the left of the fold plus [trrl, trrr) to its right
    for i in range(k_fold):
        trll = 0
        trlr = i * seg
        vall = trlr
        valr = i * seg + seg
        trrl = valr
        trrr = total_size
        print("train indices: [%d,%d),[%d,%d), test indices: [%d,%d)"
              % (trll, trlr, trrl, trrr, vall, valr))

        train_left_indices = list(range(trll, trlr))
        train_right_indices = list(range(trrl, trrr))

        train_indices = train_left_indices + train_right_indices
        val_indices = list(range(vall, valr))

        train_set = torch.utils.data.Subset(dataset, train_indices)
        val_set = torch.utils.data.Subset(dataset, val_indices)

        print(len(train_set), len(val_set))

        train_loader = torch.utils.data.DataLoader(train_set, batch_size=50,
                                                   shuffle=True, num_workers=4)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=50,
                                                 shuffle=True, num_workers=4)

        # train/valid are the training and evaluation loops (defined elsewhere)
        train_acc = train(model, criterion, optimizer, train_loader, epoch=1)
        train_score.at[i] = train_acc
        val_acc = valid(model, criterion, optimizer, val_loader)
        val_score.at[i] = val_acc

    return train_score, val_score


train_score, val_score = crossvalid(res_model, criterion, optimizer, dataset=tiny_dataset)
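
The train and valid helpers called inside crossvalid are the answerer's own training and evaluation loops and are not shown in the post. Purely as an illustrative sketch (assuming a classification model whose outputs are class scores and that both helpers return an accuracy), they might look roughly like this:

import torch

def train(model, criterion, optimizer, loader, epoch=1):
    # one or more passes over the training fold; returns training accuracy
    model.train()
    correct, total = 0, 0
    for _ in range(epoch):
        for inputs, targets in loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            total += targets.size(0)
    return correct / total

def valid(model, criterion, optimizer, loader):
    # evaluation pass over the validation fold; optimizer is unused but kept
    # to match the call site above; returns validation accuracy
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs)
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            total += targets.size(0)
    return correct / total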


To get an intuition about the correctness of what we are doing, look at the output below:

 train indices: [0,0),[3600,18000), test indices: [0,3600)
14400 3600

train indices: [0,3600),[7200,18000), test indices: [3600,7200)
14400 3600

train indices: [0,7200),[10800,18000), test indices: [7200,10800)
14400 3600

train indices: [0,10800),[14400,18000), test indices: [10800,14400)
14400 3600

train indices: [0,14400),[18000,18000), test indices: [14400,18000)
14400 3600

Originally written by Skipper, licensed under CC BY-SA 4.0.
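
A side note that is not part of the original answer: instead of computing the fold boundaries by hand, the fold indices can also be generated with scikit-learn's KFold and wrapped in Subset and DataLoader objects. A minimal sketch, assuming dataset is a map-style PyTorch dataset and kfold_loaders is just an illustrative name:

import numpy as np
import torch
from sklearn.model_selection import KFold

def kfold_loaders(dataset, k_fold=5, batch_size=50):
    # yields one (train_loader, val_loader) pair per fold
    kf = KFold(n_splits=k_fold, shuffle=True, random_state=0)
    indices = np.arange(len(dataset))
    for train_idx, val_idx in kf.split(indices):
        train_set = torch.utils.data.Subset(dataset, train_idx.tolist())
        val_set = torch.utils.data.Subset(dataset, val_idx.tolist())
        yield (torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True),
               torch.utils.data.DataLoader(val_set, batch_size=batch_size))

# usage, e.g. with the tiny_dataset from the answer:
# for fold, (train_loader, val_loader) in enumerate(kfold_loaders(tiny_dataset)):
#     ...train and evaluate on this fold...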
