# construct_nl_data.py
# Input is a CSV with three columns: user_id, groundtruth (mapped movie IDs,
# not original IDs), and top10_recommendations (movie IDs before mapping back
# to the original ID space).
import argparse
import ast
import pickle

import pandas as pd

def convert_items_to_original_ids(processed_user_id, groundtruth_item_ids, top10_item_ids):
    """Map processed (remapped) IDs back to the original MovieLens IDs.

    Args:
        processed_user_id: int, user ID in the remapped ID space
        groundtruth_item_ids: list of int, ground-truth item IDs (remapped)
        top10_item_ids: list of int, top-10 recommended item IDs (remapped)
    Returns:
        original_user_id: int
        original_gt_item_ids: list of int
        original_top10_item_ids: list of int
    """
    with open('RippleNet-PyTorch-baseline/data/item_map.pkl', 'rb') as f:
        item_map = pickle.load(f)
    with open('RippleNet-PyTorch-baseline/data/user_map.pkl', 'rb') as f:
        user_map = pickle.load(f)
    # The pickled maps go original -> processed; invert them for the lookup.
    item_new2old = {v: k for k, v in item_map.items()}
    user_new2old = {v: k for k, v in user_map.items()}
    original_gt_item_ids = [int(item_new2old[pid]) for pid in groundtruth_item_ids]
    original_top10_item_ids = [int(item_new2old[pid]) for pid in top10_item_ids]
    original_user_id = user_new2old[processed_user_id]
    return original_user_id, original_gt_item_ids, original_top10_item_ids
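
# For reference, a minimal sketch of the structure the pickled maps are assumed
# to have (original ID -> contiguous processed ID). The concrete values below
# are illustrative only, not taken from the real files:
#
#   item_map = {'1193': 0, '661': 1, '914': 2}          # original -> processed
#   item_new2old = {v: k for k, v in item_map.items()}  # processed -> original
#   item_new2old[1]   # -> '661'; keys may be strings, hence the int(...) cast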

def generate_sequential_dataset(input_data, input_path,
                                train_set_path='dataset/movielens_1M/train_set.csv',
                                seq_length=10):
    user_id = input_data['user_id'].tolist()
    groundtruth = input_data['groundtruth'].tolist()
    top10_recommendations = input_data['top10_recommendations'].tolist()
    train_set = pd.read_csv(train_set_path)
    original_user_ids = []        # List[int]
    original_gt_item_ids = []     # List[List[int]]
    original_top10_item_ids = []  # List[List[int]]
    for uid, gt_list, top10_list in zip(user_id, groundtruth, top10_recommendations):
        # The list columns are stored as strings, e.g. "[1, 2, 3]"; parse them
        # safely with ast.literal_eval rather than eval.
        gt_items = ast.literal_eval(gt_list)
        top10_items = ast.literal_eval(top10_list)
        if "RippleNet-PyTorch-baseline" in input_path:
            # RippleNet output uses remapped IDs; convert back to original IDs.
            uid, gt_items, top10_items = convert_items_to_original_ids(uid, gt_items, top10_items)
        original_user_ids.append(uid)
        original_gt_item_ids.append(gt_items)
        original_top10_item_ids.append(top10_items)
    dataset = []
    for uid, gt_list, top10_list in zip(original_user_ids, original_gt_item_ids, original_top10_item_ids):
        # Take this user's training interactions in chronological order.
        user_rows = train_set[train_set['userId'] == uid].sort_values('timestamp')
        user_history = user_rows['movieId'].tolist()
        # Keep only the most recent seq_length interactions.
        if len(user_history) > seq_length:
            user_history = user_history[-seq_length:]
        entry = {
            'userId': uid,
            'history': user_history,
            'candidates': top10_list,
            'groundtruth': gt_list
        }
        dataset.append(entry)
    return pd.DataFrame(dataset)
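
# Example of one row of the resulting DataFrame (all IDs illustrative, not
# taken from the real data):
#
#   {'userId': 1,
#    'history': [3186, 1721, 1270],    # last seq_length train items, oldest first
#    'candidates': [2858, 260, 1196],  # top-10 recommendations (original IDs)
#    'groundtruth': [527]}             # held-out items the user actually rated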

def get_movie_description(row):
    # Genres may be stored as a stringified list; parse it safely.
    genres = row['genres']
    if isinstance(genres, str):
        genres = ast.literal_eval(genres)
    if not isinstance(genres, list):  # Ensure genres is a list
        genres = [str(genres)]
    return f"Movie {row['movieId']}: {row['title']} (Genres: {', '.join(genres)}; " \
           f"Language: {row['original_language']}; Overview: {row['short_overview']})"

def process_sequence(row):
    # Note: .isin() filtering returns descriptions in the order they appear in
    # filtered_movies, not in the order of the input ID lists.
    history = row['history']
    history_movies = filtered_movies[filtered_movies['movieId'].isin(history)].apply(get_movie_description, axis=1).tolist()
    candidates = row['candidates']
    candidate_movies = filtered_movies[filtered_movies['movieId'].isin(candidates)].apply(get_movie_description, axis=1).tolist()
    next_items = row['groundtruth']
    ground_truth = filtered_movies[filtered_movies['movieId'].isin(next_items)].apply(get_movie_description, axis=1).tolist()
    return pd.Series({
        'user_id': row['userId'],
        'history': ','.join(history_movies),
        'candidates': ','.join(candidate_movies),
        'ground_truth': ','.join(ground_truth)
    })
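
# The returned Series becomes one output row with four text columns
# (user_id, history, candidates, ground_truth), where the last three are
# comma-joined movie descriptions produced by get_movie_description above.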

def build_natural_language_dataset(seq_dataset,
                                   big_table="dataset/movielens_1M/used_movies.csv",
                                   n=100):
    """
    Args:
        seq_dataset: pd.DataFrame, sequential dataset
        big_table: str, path to the movie metadata table
        n: int, number of samples to process (i.e. the size of the dataset
           used for inference); if None, process all rows
    """
    # process_sequence reads filtered_movies as a module-level global.
    global filtered_movies
    filtered_movies = pd.read_csv(big_table)
    if n is not None:
        natural_language_dataset = seq_dataset[:n].apply(process_sequence, axis=1)
    else:
        natural_language_dataset = seq_dataset.apply(process_sequence, axis=1)
    return natural_language_dataset
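
# Minimal usage sketch, assuming a predictions CSV with the three expected
# columns exists at the hypothetical path "preds.csv":
#
#   seq = generate_sequential_dataset(pd.read_csv("preds.csv"), "preds.csv")
#   nl = build_natural_language_dataset(seq, n=100)
#   nl.to_csv("natural_language_dataset.csv", index=False)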

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, default="dataset/movielens_1M/test_set.csv",
                        help='3-column dataset for inference: user_id, groundtruth, top10_recommendations')
    parser.add_argument("--output_path", type=str, default="natural_language_dataset.csv",
                        help='natural language dataset for further inference')
    parser.add_argument("--n", type=int, default=1000,
                        help='number of samples to process (the size of the inference dataset)')
    args = parser.parse_args()
    # Read the input file and rebuild per-user interaction sequences.
    input_data = pd.read_csv(args.input_path)
    seq_dataset = generate_sequential_dataset(input_data, args.input_path)
    # Sort by ground-truth length and keep the n samples with the fewest
    # ground-truth items.
    seq_dataset['groundtruth_length'] = seq_dataset['groundtruth'].apply(len)
    seq_dataset = seq_dataset.sort_values(by='groundtruth_length').head(args.n)
    max_groundtruth_length = seq_dataset['groundtruth_length'].max()
    print(f"The maximum groundtruth length in the selected {args.n} samples is: {max_groundtruth_length}")
    natural_language_dataset = build_natural_language_dataset(seq_dataset, n=args.n)
    natural_language_dataset = natural_language_dataset.sort_values(by='user_id')
    natural_language_dataset.to_csv(args.output_path, index=False)
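
# Example invocation (the RippleNet input path below is hypothetical; any path
# containing "RippleNet-PyTorch-baseline" triggers the ID-conversion branch):
#
#   python construct_nl_data.py \
#       --input_path RippleNet-PyTorch-baseline/output/top10_predictions.csv \
#       --output_path natural_language_dataset.csv \
#       --n 1000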