diff --git a/CATS/models/basemodel.py b/CATS/models/basemodel.py
index f80a6f0..0400c82 100644
--- a/CATS/models/basemodel.py
+++ b/CATS/models/basemodel.py
@@ -1,10 +1,17 @@
-from typing import Callable, List, Literal, Tuple, Union
+
+import logging
+import time
+from typing import Callable, Dict, Iterable, List, Literal, Tuple, Union
+
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from sklearn.metrics import *
+from tensorflow.keras.callbacks import Callback
+from torch.utils.data import DataLoader, TensorDataset
+from tqdm import tqdm
 
 from ..callbacks import History
 from ..inputs import (DenseFeat, SparseFeat, VarLenSparseFeat,
@@ -12,6 +19,8 @@
                       embedding_lookup, get_dense_inputs)
 from ..layers import PredictionLayer
 
+from tensorflow.python.keras.callbacks import CallbackList
+
 
 class BaseModel(nn.Module):
     def __init__(
@@ -72,6 +81,174 @@ def __init__(
 
         self.history = History()
 
+    def fit(
+        self,
+        x: Union[List[np.ndarray], Dict[str, np.ndarray]],
+        y: Union[np.ndarray, List[np.ndarray]],
+        batch_size: int = 256,
+        epochs: int = 1,
+        verbose: int = 1,
+        initial_epoch: int = 0,
+        validation_split: float = 0.0,
+        shuffle: bool = True,
+        callbacks: List[Callback] = None,
+    ) -> History:
+        """
+        Train the model and return a record of the training run.
+        :param x: Numpy array of training data (if the model has a single input), or list of numpy arrays (if the
+            model has multiple inputs). If input layers in the model are named, you can also pass a dictionary
+            mapping input names to numpy arrays.
+        :param y: Numpy array of target (label) data, or list of numpy arrays.
+        :param batch_size: Integer. Number of samples per gradient update.
+        :param epochs: Integer. Number of epochs to train the model.
+        :param verbose: Integer. 0, 1, or 2. Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch.
+        :param initial_epoch: Integer. Epoch at which to start training.
+        :param validation_split: Float between 0 and 1. Fraction of the training data to be used as validation data.
+        :param shuffle: Bool. Whether to shuffle the order of the batches at the beginning of each epoch.
+        :param callbacks: List of `Callback` instances applied during training, e.g. [`EarlyStopping`, `ModelCheckpoint`].
+        :return: A `History` object. Its `History.history` attribute is a record of training loss values and metrics
+            values at successive epochs, as well as validation loss values and validation metrics values (if applicable).
+        """
+        # setting train & validation data
+        if isinstance(x, dict):
+            x = [x[feature] for feature in self.feature_index]
+
+        do_validation = False
+        if validation_split and 0. < validation_split <= 1.0:
+            do_validation = True
+        if do_validation:
+            if hasattr(x[0], "shape"):
+                split_at = int(x[0].shape[0] * (1. - validation_split))
+            else:
+                split_at = int(len(x[0]) * (1. - validation_split))
+            x, val_x = [x_v[:split_at] for x_v in x], [x_v[split_at:] for x_v in x]
+            y, val_y = y[:split_at], y[split_at:]
+            y = np.asarray(y)
+        else:
+            val_x = []
+            val_y = []
+
+        for i in range(len(x)):
+            if len(x[i].shape) == 1:
+                x[i] = np.expand_dims(x[i], axis=1)
+
+        train_tensor_data = TensorDataset(
+            torch.from_numpy(np.concatenate(x, axis=-1)), torch.from_numpy(y)
+        )
+
+        model = self.train()
+        loss_func = self.loss_func
+        optim = self.optim
+
+        # setting dataloader
+        train_loader = DataLoader(
+            dataset=train_tensor_data, shuffle=shuffle, batch_size=batch_size
+        )
+        sample_num = len(train_tensor_data)
+        steps_per_epoch = (sample_num - 1) // batch_size + 1
+
+        # configure callbacks
+        callbacks = (callbacks or []) + [self.history]  # add history callback
+        callbacks = CallbackList(callbacks)
+        callbacks.set_model(self)
+        callbacks.on_train_begin()
+        callbacks.set_model(self)
+        if not hasattr(callbacks, "model"):
+            callbacks.__setattr__("model", self)
+        callbacks.model.stop_training = False
+
+        # Training
+        logging.info(
+            "Train on {0} samples, validate on {1} samples, {2} steps per epoch".format(
+                len(train_tensor_data), len(val_y), steps_per_epoch
+            )
+        )
+
+        for epoch in range(initial_epoch, epochs):
+            callbacks.on_epoch_begin(epoch)
+            epoch_logs = {}
+            start_time = time.time()
+            loss_epoch = 0
+            total_loss_epoch = 0
+            train_result = {}
+            try:
+                with tqdm(enumerate(train_loader), disable=verbose != 1) as t:
+                    for _, (x_train, y_train) in t:
+                        x = x_train.to(self.device).float()
+                        y = y_train.to(self.device).float()
+
+                        y_pred = model(x).squeeze()
+
+                        optim.zero_grad()
+                        if isinstance(loss_func, list):
+                            assert (
+                                len(loss_func) == self.num_tasks
+                            ), "the length of `loss_func` should be equal to `self.num_tasks`"
+                            loss = sum(
+                                [
+                                    loss_func[i](y_pred[:, i], y[:, i], reduction="sum")
+                                    for i in range(self.num_tasks)
+                                ]
+                            )
+                        else:
+                            loss = loss_func(y_pred, y.squeeze(), reduction="sum")
+                        reg_loss = self.get_regularization_loss()
+
+                        total_loss = loss + reg_loss + self.aux_loss
+
+                        loss_epoch += loss.item()
+                        total_loss_epoch += total_loss.item()
+                        total_loss.backward()
+                        optim.step()
+
+                        if verbose > 0:
+                            for name, metric_fun in self.metrics.items():
+                                if name not in train_result:
+                                    train_result[name] = []
+                                train_result[name].append(
+                                    metric_fun(
+                                        y.cpu().data.numpy(),
+                                        y_pred.cpu().data.numpy().astype("float64"),
+                                    )
+                                )
+            except KeyboardInterrupt:
+                t.close()
+                raise
+            t.close()
+
+            epoch_logs["loss"] = total_loss_epoch / sample_num
+            for name, result in train_result.items():
+                epoch_logs[name] = np.sum(result) / steps_per_epoch
+
+            # verbose
+            if verbose > 0:
+                epoch_time = int(time.time() - start_time)
+                logging.info("Epoch {0}/{1}".format(epoch + 1, epochs))
+
+                eval_str = "{0}s - loss: {1: .4f}".format(
+                    epoch_time, epoch_logs["loss"]
+                )
+
+                for name in self.metrics:
+                    eval_str += " - " + name + ": {0: .4f}".format(epoch_logs[name])
+
+                if do_validation:
+                    for name in self.metrics:
+                        eval_str += (
+                            " - "
+                            + "val_"
+                            + name
+                            + ": {0: .4f}".format(epoch_logs["val_" + name])
+                        )
+                logging.info(eval_str)
+            callbacks.on_epoch_end(epoch, epoch_logs)
+            if self.stop_training:
+                break
+
+        callbacks.on_train_end()
+
+        return self.history
+
     def compile(
         self,
         optimizer: Union[
@@ -275,3 +452,43 @@ def input_from_feature_columns(
         )
 
         return sparse_embedding_list, dense_value_list
+
+    def add_regularization_weight(
+        self,
+        weight_list: Iterable[torch.nn.parameter.Parameter],
+        l1: float = 0.0,
+        l2: float = 0.0,
+    ):
+        """
+        Add L1 and L2 regularization to the given set of weights.
+        :param weight_list: A list of parameters (weights) to which regularization will be added.
+        :param l1: The lambda value determining the strength of L1 regularization.
+        :param l2: The lambda value determining the strength of L2 regularization.
+        """
+        weight_list = [weight_list]
+        self.regularization_weight.append((weight_list, l1, l2))
+
+    def get_regularization_loss(self) -> torch.Tensor:
+        """
+        Calculate and return the total regularization loss for all the parameters (weights)
+        previously registered through the `add_regularization_weight` method.
+        :return: torch.Tensor. The total regularization loss.
+        """
+        total_reg_loss = torch.zeros((1,), device=self.device)
+        for weight_list, l1, l2 in self.regularization_weight:
+            for w in weight_list:
+                if isinstance(w, tuple):
+                    parameter = w[1]
+                else:
+                    parameter = w
+                if l1 > 0:
+                    total_reg_loss += torch.sum(l1 * torch.abs(parameter))
+
+                if l2 > 0:
+                    try:
+                        total_reg_loss += torch.sum(l2 * torch.square(parameter))
+                    except AttributeError:
+                        total_reg_loss += torch.sum(l2 * parameter * parameter)
+
+        return total_reg_loss
diff --git a/poetry.lock b/poetry.lock
index d4db495..5e8b831 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -909,6 +909,27 @@ typing-extensions = "*"
 [package.extras]
 opt-einsum = ["opt-einsum (>=3.3)"]
 
+[[package]]
+name = "tqdm"
+version = "4.67.1"
+description = "Fast, Extensible Progress Meter"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"},
+    {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
+[package.extras]
+dev = ["nbval", "pytest (>=6)", "pytest-asyncio (>=0.24)", "pytest-cov", "pytest-timeout"]
+discord = ["requests"]
+notebook = ["ipywidgets (>=6)"]
+slack = ["slack-sdk"]
+telegram = ["requests"]
+
 [[package]]
 name = "typing-extensions"
 version = "3.7.4.3"
@@ -984,4 +1005,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "1a37547bedf7ec5c1728a0d31636d602c105f2282f8159a822c334e94aeeecbf"
+content-hash = "ac189dedb9c1deb9b077424a83446ec5f572ccb444e4d7dc07c2d8ff468659d1"
diff --git a/pyproject.toml b/pyproject.toml
index d481704..6e49217 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,6 +13,7 @@
 scikit-learn = "1.1.3"
 pandas = "1.1.5"
 keras = "2.6.0"
 protobuf = "3.20.1"
+tqdm = "^4.67.1"
 
 [tool.poetry.group.dev.dependencies]
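Reviewer note: a minimal usage sketch of the `fit` API added above. It is not part of the patch; the concrete model class `MyCTRModel`, its constructor argument, the feature names, and the `compile` argument are illustrative assumptions, and only the `fit` signature and the `History.history` attribute are taken from the diff. Whether `tensorflow.keras` callbacks fully interoperate with the `CallbackList` wiring here is likewise an assumption beyond what the diff shows.

```python
import numpy as np
from tensorflow.keras.callbacks import EarlyStopping

# Hypothetical concrete subclass of BaseModel; constructor arguments are illustrative only.
model = MyCTRModel(feature_columns)
model.compile(optimizer="adam")  # optimizer choice shown here is an assumption

# Inputs keyed by feature name (resolved against self.feature_index), labels as a numpy array.
x = {
    "user_id": np.random.randint(0, 1000, size=4096),
    "item_id": np.random.randint(0, 500, size=4096),
}
y = np.random.randint(0, 2, size=4096).astype("float32")

history = model.fit(
    x,
    y,
    batch_size=256,
    epochs=5,
    validation_split=0.2,  # last 20% of the samples are held out for validation
    callbacks=[EarlyStopping(monitor="loss", patience=2)],
)
print(history.history)  # per-epoch loss and metric values recorded by the History callback
```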
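For the regularization helpers, a standalone sketch of the penalty that `get_regularization_loss` accumulates for a single entry registered via `add_regularization_weight(w, l2=1e-5)`; the tensor shape and lambda values are illustrative, and only the L1/L2 formula mirrors the code in the diff.

```python
import torch

# One registered (weight_list, l1, l2) entry; shape and lambdas are illustrative.
w = torch.randn(1000, 8, requires_grad=True)
l1, l2 = 0.0, 1e-5

penalty = torch.zeros((1,))
if l1 > 0:
    penalty += torch.sum(l1 * torch.abs(w))      # L1 term
if l2 > 0:
    penalty += torch.sum(l2 * torch.square(w))   # L2 term
print(penalty)  # scalar tensor added to the batch loss alongside self.aux_loss in fit()
```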