-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmovielens_dataset.py
32 lines (26 loc) · 1.11 KB
/
movielens_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
import torch
class MovieLensDataset(torch.utils.data.Dataset):
    """MovieLens ratings exposed as a binary-label PyTorch dataset.

    Data preparation:
        Samples with a rating of 3 or lower are treated as negative (0.0);
        samples with a rating above 3 are treated as positive (1.0).

    :param dataset: table-like object exposing ``to_numpy()`` (for example a
        pandas ``DataFrame``). Only its first three columns are used: the
        first two are taken as the integer (user, item) IDs and the third as
        the rating.
    :param sep: unused; kept so existing callers that pass it keep working.
    :param engine: unused; kept so existing callers that pass it keep working.
    :param header: unused; kept so existing callers that pass it keep working.

    Reference:
        https://grouplens.org/datasets/movielens
    """

    def __init__(self, dataset, sep=',', engine='c', header='infer'):
        super().__init__()
        data = dataset.to_numpy()[:, :3]
        # IDs are stored exactly as given (no re-basing to zero); the "+ 1"
        # in field_dims below makes room for the largest ID value.
        # Explicit int64 replaces the np.int / np.long aliases, which were
        # removed from NumPy in 1.24.
        self.items = data[:, :2].astype(np.int64)
        self.targets = self.__preprocess_target(data[:, 2]).astype(np.float32)
        self.field_dims = np.max(self.items, axis=0) + 1
        self.user_field_idx = np.array((0,), dtype=np.int64)
        self.item_field_idx = np.array((1,), dtype=np.int64)

    def __len__(self):
        """Return the number of samples."""
        return self.targets.shape[0]

    def __getitem__(self, index):
        """Return the ``(ids, label)`` pair stored at ``index``."""
        return self.items[index], self.targets[index]

    def __preprocess_target(self, target):
        """Binarize ratings: 1.0 where the rating is above 3, else 0.0.

        A new array is returned and ``target`` is left untouched, so a view
        into the caller's data is never overwritten (and a read-only view
        does not cause an error).
        """
        ratings = np.asarray(target, dtype=np.float64)
        return np.where(ratings > 3, 1.0, 0.0)