我有一个如下所示的数据集。那就是第一项是用户 ID,然后是用户单击的项目集。
0 24104 27359 6684
0 24104 27359
1 16742 31529 31485
1 16742 31529
2 6579 19316 13091 7181 6579 19316 13091
2 6579 19316 13091 7181 6579 19316
2 6579 19316 13091 7181 6579 19316 13091 6579
2 6579 19316 13091 7181 6579
4 19577 21608
4 19577 21608
4 19577 21608 18373
5 3541 9529
5 3541 9529
6 6832 19218 14144
6 6832 19218
7 9751 23424 25067 12606 26245 23083 12606
我定义了一个自定义数据集来处理我的点击日志数据。
import torch.utils.data as data
class ClickLogDataset(data.Dataset):
def __init__(self, data_path):
self.data_path = data_path
self.uids = []
self.streams = []
with open(self.data_path, 'r') as fdata:
for row in fdata:
row = row.strip('\n').split('\t')
self.uids.append(int(row[0]))
self.streams.append(list(map(int, row[1:])))
def __len__(self):
return len(self.uids)
def __getitem__(self, idx):
uid, stream = self.uids[idx], self.streams[idx]
return uid, stream
然后我使用 DataLoader 从数据中检索小批量进行训练。
from torch.utils.data.dataloader import DataLoader
clicklog_dataset = ClickLogDataset(data_path)
clicklog_data_loader = DataLoader(dataset=clicklog_dataset, batch_size=16)
for uid_batch, stream_batch in stream_data_loader:
print(uid_batch)
print(stream_batch)
上面的代码返回的结果与我预期的不同,我希望 stream_batch
是长度为整数类型的二维张量 16
。然而,我得到的是一个长度为 16 的一维张量列表,并且该列表只有一个元素,如下所示。这是为什么 ?
#stream_batch
[tensor([24104, 24104, 16742, 16742, 6579, 6579, 6579, 6579, 19577, 19577,
19577, 3541, 3541, 6832, 6832, 9751])]
原文由 Trung Le 发布,翻译遵循 CC BY-SA 4.0 许可协议
那么你如何处理样本长度不同的事实呢?
torch.utils.data.DataLoader
有一个collate_fn
参数,用于将样本列表转换为批次。 默认情况下,它会对列表执行 此 操作。您可以编写自己的collate_fn
,例如0
填充输入,将其截断为某个预定义的长度或应用您选择的任何其他操作。