-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* created the environment and inserted the split and filter methods, in addition to refactoring the DatasetLoader.py and dataset.py modules * fixed random split * removed comments * refactor DefaultDatasetLoader.py and dataset.py, created folder for filters and splits * Set setup and make * Saving progress in the enviornment refactor * Implement loader * Finish enviornment implementation * Fix loader and split-base * Adding loaders to the iRec * Delete .idea directory * environment integration: environment/load environment/split and environment/filter * environment integration: environment/load environment/split and environment/filter * finished the integration of the fixed and updated train-test load registry.py * fixed num_total_users/items * fixed imports * Fix validation * A simple example of tests * fixed return train test dataset * added documentation for load module * fixed assert * Added docstrings and type hints * Added docstrings and typehints * Update docs for dataset.py and fixed warnings * Fix simple returns * Fix bugs * Add bdd tests * updated requirements * removed unit test * refactor: removed idea directory * refactor: removed unnecessary Makefile * fixed erros in yaml * Update InteractionMetricEvaluator.py * remove: traitlets dependency in run_agent * feat: dev requirements included behave * remove: redundant setup file * refactor: removed all app branch changes Co-authored-by: thiagodks <thiagoadriano2010@gmail.com> Co-authored-by: Nicollas Silva <ncsilvaa@Nicollass-MBP.lan> Co-authored-by: Thiago Silva <48692251+thiagodks@users.noreply.github.com> Co-authored-by: Carlos Mito <carlosmsmito@gmail.com> Co-authored-by: heitor57 <heitorwerneck@hotmail.com>
- Loading branch information
1 parent
51f499e
commit 74bdc28
Showing
31 changed files
with
988 additions
and
763 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ data/ | |
*.aux | ||
*.log | ||
*.csv | ||
.idea/ | ||
|
||
.vim/coc-settings.json | ||
# Byte-compiled / optimized / DLL files | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from typing import List | ||
import numpy as np | ||
|
||
|
||
class Dataset:
    """Interaction array together with summary statistics derived from it.

    ``data`` is indexed as a 2-D array whose column 0 is read as user ids,
    column 1 as item ids and column 2 as ratings.
    """

    # Class-level defaults so the attributes exist before set_parameters() runs.
    num_users = 0
    num_items = 0
    rate_domain = set()
    max_uid = 0
    max_iid = 0
    mean_rating = 0
    min_rating = 0
    max_rating = 0

    def __init__(
            self,
            data: np.ndarray
    ):
        """__init__

        Args:
            data (np.ndarray): the data
        """
        self.data = data
        self.num_total_users = 0
        self.num_total_items = 0
        # The class-level ``rate_domain`` is a single set object shared by all
        # instances; give every instance its own so in-place mutation on one
        # dataset cannot leak into another.
        self.rate_domain = set()

    @staticmethod
    def normalize_ids(ids: List) -> np.ndarray:
        """normalize_ids

        Normalizes the ids by putting them in sequence: each id is replaced
        by its position among the sorted distinct ids.

        Args:
            ids (List): list of ids

        Returns:
            result (np.ndarray): the normalized ids
        """
        # np.unique already returns its values sorted, so no extra sort is needed.
        unique_values = np.unique(ids)
        return np.searchsorted(unique_values, ids)

    def reset_index(self):
        """reset_index

        Resets user and item indices by overwriting columns 0 and 1 of
        ``self.data`` in place with their normalized ids.
        """
        self.data[:, 0] = self.normalize_ids(self.data[:, 0])
        self.data[:, 1] = self.normalize_ids(self.data[:, 1])

    def set_parameters(self):
        """set_parameters

        Calculates and updates the database parameters from ``self.data``.
        """
        ratings = self.data[:, 2]

        # Distinct ids are computed once per column and reused for the counts.
        self.uids = np.unique(self.data[:, 0]).astype(int)
        self.iids = np.unique(self.data[:, 1]).astype(int)
        self.num_users = len(self.uids)
        self.num_items = len(self.iids)
        self.max_uid = np.max(self.uids)
        self.max_iid = np.max(self.iids)

        self.rate_domain = set(np.unique(ratings))
        self.mean_rating = np.mean(ratings)
        self.min_rating = np.min(ratings)
        self.max_rating = np.max(ratings)

    def update_num_total_users_items(self, num_total_users=0, num_total_items=0):
        """update_num_total_users_items

        Updates the total number of users and items. Each total is the
        supplied value when it is strictly greater than the largest known
        id plus one, otherwise the largest known id plus one.

        Args:
            num_total_users (int): candidate total number of users.
            num_total_items (int): candidate total number of items.
        """
        # max() returns its first argument on ties, which matches the
        # "only take the supplied value when strictly greater" rule.
        self.num_total_users = max(self.max_uid + 1, num_total_users)
        self.num_total_items = max(self.max_iid + 1, num_total_items)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import random | ||
from pandas import DataFrame | ||
|
||
|
||
class FilteringByItems:
    """FilteringByItems.

    This class contains different filtering by item approaches. All filters
    read the ``itemId`` column of the given DataFrame.
    """

    def __init__(self):
        pass

    @staticmethod
    def min_ratings(df_dataset: DataFrame, min_ratings: int) -> DataFrame:
        """min_ratings.

        This function removes items whose total number of
        ratings is less than [min_ratings].

        Args:
            df_dataset (DataFrame): the data to be filtered.
            min_ratings (int): minimum number of ratings.

        Returns:
            The data filtered by the number of ratings.
        """
        # Number of (non-null) userId entries recorded for each item.
        ratings_per_item = df_dataset.groupby("itemId")["userId"].count()
        selected_items = ratings_per_item[ratings_per_item >= min_ratings].index
        return df_dataset[df_dataset["itemId"].isin(selected_items)]

    @staticmethod
    def num_items(df_dataset: DataFrame, num_items: int) -> DataFrame:
        """num_items.

        This function limits the number of distinct items in the dataset by
        keeping a random sample of [num_items] items.

        Args:
            df_dataset (DataFrame): the data to be filtered.
            num_items (int): maximum number of items in the dataset.

        Returns:
            The data filtered by the number of items. When [num_items] cannot
            be sampled (e.g. it exceeds the number of distinct items), the
            data is returned unchanged.

        Raises:
            KeyError: if the data has no ``itemId`` column.
        """
        # Column lookup stays outside the try block so a missing column is
        # reported instead of silently returning unfiltered data.
        candidates = list(df_dataset["itemId"].unique())
        try:
            selected_items = random.sample(candidates, num_items)
        except (ValueError, TypeError):
            # Sample size larger than the population, negative, or not an
            # integer: keep the best-effort behaviour of returning everything.
            return df_dataset
        return df_dataset[df_dataset["itemId"].isin(selected_items)]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import random | ||
from pandas import DataFrame | ||
|
||
|
||
class FilteringByUsers:
    """FilteringByUsers.

    This class contains different filtering by users approaches. All filters
    read the ``userId`` column of the given DataFrame.
    """

    def __init__(self):
        pass

    @staticmethod
    def min_consumption(df_dataset: DataFrame, min_consumption: int) -> DataFrame:
        """min_consumption.

        This function removes users whose total number of
        consumptions is less than [min_consumption].

        Args:
            df_dataset (DataFrame): the data to be filtered.
            min_consumption (int): minimum number of items consumed by a user.

        Returns:
            The data filtered by the number of consumptions.
        """
        # Number of (non-null) itemId entries recorded for each user.
        items_per_user = df_dataset.groupby("userId")["itemId"].count()
        selected_users = items_per_user[items_per_user >= min_consumption].index
        return df_dataset[df_dataset["userId"].isin(selected_users)]

    @staticmethod
    def num_users(df_dataset: DataFrame, num_users: int) -> DataFrame:
        """num_users.

        This function limits the number of distinct users in the dataset by
        keeping a random sample of [num_users] users.

        Args:
            df_dataset (DataFrame): the data to be filtered.
            num_users (int): maximum number of users in the dataset.

        Returns:
            The data filtered by the number of users. When [num_users] cannot
            be sampled (e.g. it exceeds the number of distinct users), the
            data is returned unchanged.

        Raises:
            KeyError: if the data has no ``userId`` column.
        """
        # Column lookup stays outside the try block so a missing column is
        # reported instead of silently returning unfiltered data.
        candidates = list(df_dataset["userId"].unique())
        try:
            selected_users = random.sample(candidates, num_users)
        except (ValueError, TypeError):
            # Sample size larger than the population, negative, or not an
            # integer: keep the best-effort behaviour of returning everything.
            return df_dataset
        return df_dataset[df_dataset["userId"].isin(selected_users)]
Oops, something went wrong.