From 55cfd84b17da412247e93f37cde21e7d566cc5a3 Mon Sep 17 00:00:00 2001 From: Calvin Che Date: Fri, 30 Aug 2024 13:55:18 +0800 Subject: [PATCH] Fix same min/max values --- pyproject.toml | 2 +- .../data_processing/scaler/min_max_scaler.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a5fc485..c32b335 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "synnax-shared" -version = "1.7.2" +version = "1.7.3" description = "Synnax shared Python pacakges" readme = "README.md" requires-python = ">=3.11" diff --git a/src/synnax_shared/data_processing/scaler/min_max_scaler.py b/src/synnax_shared/data_processing/scaler/min_max_scaler.py index bc824cc..5d51927 100644 --- a/src/synnax_shared/data_processing/scaler/min_max_scaler.py +++ b/src/synnax_shared/data_processing/scaler/min_max_scaler.py @@ -15,9 +15,15 @@ def __init__(self, min: float | None = None, max: float | None = None): self.max = max def fit_transform(self, ndarray: NDArray) -> NDArray: + if self.min is not None or self.max is not None: + raise ValueError("MinMaxScaler is already fitted") float_ndarray = ndarray.astype(float) - self.min = numpy.nanmin(float_ndarray) - self.max = numpy.nanmax(float_ndarray) + min = numpy.nanmin(float_ndarray) + max = numpy.nanmax(float_ndarray) + if min == max: + min = max - 1 + self.min = min + self.max = max return self.transform(float_ndarray) def transform(self, ndarray: NDArray) -> NDArray: