diff --git a/.github/workflows/run-tests-cpu.yaml b/.github/workflows/run-tests-cpu.yaml index bf66900..6a36d09 100644 --- a/.github/workflows/run-tests-cpu.yaml +++ b/.github/workflows/run-tests-cpu.yaml @@ -32,16 +32,16 @@ jobs: run: uv sync --frozen --extra cpu - name: Run | Tests -> unit - run: uv run pytest tests/unit + run: uv run --no-sync pytest tests/unit - name: Build mkdocs - run: uv run mkdocs build --strict + run: uv run --no-sync mkdocs build --strict - name: Run tests -> end_to_end -> sequential - run: uv run pytest tests/end_to_end/test_tabular_sequential.py + run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential.py - name: Run tests -> end_to_end -> sequential context - run: uv run pytest tests/end_to_end/test_tabular_sequential_context.py + run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential_context.py run-tests-cpu-end-to-end-nonsequential: runs-on: ubuntu-latest @@ -66,4 +66,4 @@ jobs: run: uv sync --frozen --extra cpu --no-group docs - name: Run tests -> end_to_end all except sequential - run: uv run pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/ + run: uv run --no-sync pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/ diff --git a/.github/workflows/run-tests-gpu.yaml b/.github/workflows/run-tests-gpu.yaml index 3958cbc..269f8a9 100644 --- a/.github/workflows/run-tests-gpu.yaml +++ b/.github/workflows/run-tests-gpu.yaml @@ -42,10 +42,10 @@ jobs: run: nvidia-smi - name: Run tests -> end_to_end -> sequential - run: uv run pytest tests/end_to_end/test_tabular_sequential.py + run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential.py - name: Run tests -> end_to_end -> sequential context - run: uv run pytest tests/end_to_end/test_tabular_sequential_context.py + run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential_context.py - name: Run tests -> end_to_end all except sequential - run: uv run pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/ + run: uv run --no-sync pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/ diff --git a/Makefile b/Makefile index 853ab26..9fc72a1 100644 --- a/Makefile +++ b/Makefile @@ -12,11 +12,11 @@ install: # Install dependencies .PHONY: lint lint: ## Run lints - uv run pre-commit run --all-files + uv run --no-sync pre-commit run --all-files .PHONY: test test: ## Run tests - uv run pytest + uv run --no-sync pytest .PHONY: all all: clean install lint test ## Run the commands: clean install lint test diff --git a/mostlyai/engine/_encoding_types/language/categorical.py b/mostlyai/engine/_encoding_types/language/categorical.py new file mode 100644 index 0000000..aebfc39 --- /dev/null +++ b/mostlyai/engine/_encoding_types/language/categorical.py @@ -0,0 +1,75 @@ +# Copyright 2025 MOSTLY AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Categorical encoding for language models. +""" + +import numpy as np +import pandas as pd + +from mostlyai.engine._common import safe_convert_string, STRING + +CATEGORICAL_UNKNOWN_TOKEN = "_RARE_" + + +def analyze_language_categorical(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict: + values = safe_convert_string(values) + # count distinct root_keys per categorical value for rare-category protection + df = pd.concat([root_keys, values], axis=1) + cnt_values = df.groupby(values.name)[root_keys.name].nunique().to_dict() + stats = {"has_nan": sum(values.isna()) > 0, "cnt_values": cnt_values} + return stats + + +def analyze_reduce_language_categorical(stats_list: list[dict], value_protection: bool = True) -> dict: + # sum up all counts for each categorical value + cnt_values: dict[str, int] = {} + for item in stats_list: + for value, count in item["cnt_values"].items(): + cnt_values[value] = cnt_values.get(value, 0) + count + # create alphabetically sorted list of non-rare categories + known_categories = [k for k in sorted(cnt_values.keys())] + if value_protection: + # stochastic threshold for rare categories + rare_min = 5 + int(3 * np.random.uniform()) + else: + rare_min = 0 + categories = [k for k in known_categories if cnt_values[k] >= rare_min] + no_of_rare_categories = len(known_categories) - len(categories) + # add None to categories, if any are present + if any([j["has_nan"] for j in stats_list]): + categories = [None] + categories + # add special token for UNKNOWN categories at first position + if no_of_rare_categories > 0: + categories = [CATEGORICAL_UNKNOWN_TOKEN] + categories + stats = {"no_of_rare_categories": no_of_rare_categories, "categories": categories} + return stats + + +def encode_language_categorical(values: pd.Series, stats: dict) -> pd.Series: + values = safe_convert_string(values) + values = values.copy() + known_categories = stats["categories"] + mask = ~values.isin(known_categories) + if None in known_categories: + mask &= ~pd.isna(values) + values[mask] = CATEGORICAL_UNKNOWN_TOKEN + return values + + +def decode_language_categorical(x: pd.Series, col_stats: dict[str, str]) -> pd.Series: + x = x.astype(STRING) + allowed_categories = col_stats.get("categories", []) + return x.where(x.isin(allowed_categories), other=None) diff --git a/mostlyai/engine/_encoding_types/language/datetime.py b/mostlyai/engine/_encoding_types/language/datetime.py new file mode 100644 index 0000000..c7a1b37 --- /dev/null +++ b/mostlyai/engine/_encoding_types/language/datetime.py @@ -0,0 +1,143 @@ +# Copyright 2025 MOSTLY AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
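The language-categorical helpers above follow the same analyze -> reduce -> encode -> decode flow as the tabular encoders: per-partition counts of distinct root keys per category, a stochastic rare-category threshold in the reduce step, encoding of rare or unseen values to the _RARE_ token, and decoding back to known categories only. A minimal sketch of how they compose, with made-up column data (illustrative only, not part of the change set):

import pandas as pd

from mostlyai.engine._encoding_types.language.categorical import (
    CATEGORICAL_UNKNOWN_TOKEN,
    analyze_language_categorical,
    analyze_reduce_language_categorical,
    decode_language_categorical,
    encode_language_categorical,
)

values = pd.Series(["red"] * 10 + ["blue"] * 10 + ["typo"], name="color")
root_keys = pd.Series(range(len(values)), name="subject_id")

# per-partition stats: number of distinct root keys per category
partition_stats = analyze_language_categorical(values, root_keys)
# reduce: "typo" is seen for a single subject, below the stochastic threshold (5-7),
# so it counts as a rare category and _RARE_ is added to the category list
stats = analyze_reduce_language_categorical([partition_stats], value_protection=True)
# encoding replaces rare/unseen values with the _RARE_ token
encoded = encode_language_categorical(values, stats)
assert (encoded == CATEGORICAL_UNKNOWN_TOKEN).sum() == 1
# decoding keeps only known categories; anything else becomes None
decoded = decode_language_categorical(pd.Series(["red", "green"]), stats)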
+import calendar + +import numpy as np +import pandas as pd + +from mostlyai.engine._common import safe_convert_datetime + + +def analyze_language_datetime(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict: + values = safe_convert_datetime(values) + df = pd.concat([root_keys, values], axis=1) + # determine lowest/highest values by root ID, and return top 11 + min_dates = df.groupby(root_keys.name)[values.name].min().dropna() + min11 = min_dates.sort_values(ascending=True).head(11).astype(str).tolist() + max_dates = df.groupby(root_keys.name)[values.name].max().dropna() + max11 = max_dates.sort_values(ascending=False).head(11).astype(str).tolist() + # determine if there are any NaN values + has_nan = bool(values.isna().any()) + # return stats + stats = { + "has_nan": has_nan, + "min11": min11, + "max11": max11, + } + return stats + + +def analyze_reduce_language_datetime(stats_list: list[dict], value_protection: bool = True) -> dict: + # check if there are missing values + has_nan = any([j["has_nan"] for j in stats_list]) + # determine min / max 5 values to map too low / too high values to + min11 = sorted([v for min11 in [j["min11"] for j in stats_list] for v in min11], reverse=False)[:11] + max11 = sorted([v for max11 in [j["max11"] for j in stats_list] for v in max11], reverse=True)[:11] + if value_protection: + # extreme value protection - discard lowest/highest 5 values + if len(min11) < 11 or len(max11) < 11: + # less than 11 subjects with non-NULL values; we need to protect all + min5 = [] + max5 = [] + else: + min5 = [str(v) for v in min11[5:10]] # drop 1 to 5th lowest; keep 6th to 10th lowest + max5 = [str(v) for v in max11[5:10]] # drop 1 to 5th highest; keep 6th to 10th highest + else: + min5 = min11[0:5] + max5 = max11[0:5] + stats = { + "has_nan": has_nan, + "min5": min5, + "max5": max5, + } + return stats + + +def encode_language_datetime(values: pd.Series, stats: dict, _: pd.Series | None = None) -> pd.Series: + # convert + values = safe_convert_datetime(values) + values = values.copy() + # reset index, as `values.mask` can throw errors for misaligned indices + values.reset_index(drop=True, inplace=True) + # replace extreme values with randomly sampled 5-th to 10-th largest/smallest values + min5 = stats["min5"] if len(stats["min5"]) > 0 else [0] + max5 = stats["max5"] if len(stats["max5"]) > 0 else [0] + min5 = pd.Series(min5, dtype=values.dtype) + max5 = pd.Series(max5, dtype=values.dtype) + values.mask( + values < min5[0], + min5.sample(n=len(values), replace=True, ignore_index=True), + inplace=True, + ) + values.mask( + values > max5[0], + max5.sample(n=len(values), replace=True, ignore_index=True), + inplace=True, + ) + return values + + +def _clip_datetime(x: pd.Series, min5: list, max5: list) -> pd.Series: + x_dt = pd.to_datetime(x, errors="coerce") + min_arr = pd.to_datetime(min5).to_numpy(dtype="datetime64[ns]") + max_arr = pd.to_datetime(max5).to_numpy(dtype="datetime64[ns]") + n = len(x_dt) + random_mins = np.random.choice(min_arr, size=n) + random_maxs = np.random.choice(max_arr, size=n) + clipped = np.minimum(np.maximum(x_dt.to_numpy(dtype="datetime64[ns]"), random_mins), random_maxs) + return pd.Series(clipped, index=x.index) + + +def decode_language_datetime(x: pd.Series, col_stats: dict[str, str]) -> pd.Series: + x = x.where(~x.isin(["", "_INVALID_"]), np.nan) + + valid_mask = ( + x.str.len().ge(10) + & x.str.slice(0, 4).str.isdigit() + & x.str.slice(5, 7).str.isdigit() + & x.str.slice(8, 10).str.isdigit() + ) + if valid_mask.sum() >
0: # expected "YYYY-MM-DD" prefix + # handle the date portion, ensuring validity + years = x[valid_mask].str.slice(0, 4).astype(int) + months = x[valid_mask].str.slice(5, 7).astype(int) + days = x[valid_mask].str.slice(8, 10).astype(int) + + # clamp days according to maximum possible day of the month of a given year + last_days = np.array([calendar.monthrange(y, m)[1] for y, m in zip(years, months)]) + clamped_days = np.minimum(days, last_days) + + # rebuild the date portion + new_date = ( + years.astype(str).str.zfill(4) + + "-" + + months.astype(str).str.zfill(2) + + "-" + + pd.Series(clamped_days, index=years.index).astype(str).str.zfill(2) + ) + + # handle the time portion, ensuring validity + remainder = x[valid_mask].str.slice(10) + + time_regex = r"^[ T]?(\d{2}:\d{2}:\d{2}(?:\.\d+)?)" + valid_time = remainder.str.extract(time_regex, expand=False) + valid_time = valid_time.fillna("00:00:00") + valid_time = " " + valid_time + + new_date = new_date + valid_time + x.loc[valid_mask] = new_date + + x = pd.to_datetime(x, errors="coerce") + x = _clip_datetime(x, col_stats["min5"], col_stats["max5"]) + return x.astype("datetime64[ns]") diff --git a/mostlyai/engine/_encoding_types/language/numeric.py b/mostlyai/engine/_encoding_types/language/numeric.py new file mode 100644 index 0000000..a3723b4 --- /dev/null +++ b/mostlyai/engine/_encoding_types/language/numeric.py @@ -0,0 +1,136 @@ +# Copyright 2025 MOSTLY AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pandas as pd + +from mostlyai.engine._common import safe_convert_numeric +from mostlyai.engine._encoding_types.tabular.numeric import _type_safe_numeric_series +from mostlyai.engine.domain import ModelEncodingType + + +def analyze_language_numeric(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict: + values = safe_convert_numeric(values) + + # determine lowest/highest values by root ID, and return top 11 + df = pd.concat([root_keys, values], axis=1) + min_values = df.groupby(root_keys.name)[values.name].min().dropna() + min11 = min_values.sort_values(ascending=True).head(11).tolist() + max_values = df.groupby(root_keys.name)[values.name].max().dropna() + max11 = max_values.sort_values(ascending=False).head(11).tolist() + + # determine if there are any NaN values + has_nan = bool(values.isna().any()) + + # determine max scale + def count_scale(num: float) -> int: + # represent number as fixed point string, remove trailing zeros and decimal point + num = format(num, "f").rstrip("0").rstrip(".") + if "." 
in num: + # in case of decimal, return number of digits after decimal point + return len(num.split(".")[1]) + # in case of integer, return 0 + return 0 + + max_scale = int(values.apply(count_scale).max()) + + stats = { + "has_nan": has_nan, + "max_scale": max_scale, + "min11": min11, + "max11": max11, + } + return stats + + +def analyze_reduce_language_numeric(stats_list: list[dict], value_protection: bool = True) -> dict: + # check for occurrence of NaN values + has_nan = any([j["has_nan"] for j in stats_list]) + + # determine max scale + max_scale = max([j["max_scale"] for j in stats_list]) + + # determine min / max 5 values to map too low / too high values to + min11 = sorted([v for min11 in [j["min11"] for j in stats_list] for v in min11], reverse=False)[:11] + max11 = sorted([v for max11 in [j["max11"] for j in stats_list] for v in max11], reverse=True)[:11] + if value_protection: + # extreme value protection - discard lowest/highest 5 values + if len(min11) < 11 or len(max11) < 11: + # less than 11 subjects with non-NULL values; we need to protect all + min5 = [] + max5 = [] + else: + min5 = min11[5:10] # drop 1 to 5th lowest; keep 6th to 10th lowest + max5 = max11[5:10] # drop 1 to 5th highest; keep 6th to 10th highest + else: + min5 = min11[0:5] + max5 = max11[0:5] + + stats = { + "encoding_type": ModelEncodingType.language_numeric.value, + "has_nan": has_nan, + "max_scale": max_scale, + "min5": min5, + "max5": max5, + } + + return stats + + +def encode_language_numeric(values: pd.Series, stats: dict, _: pd.Series | None = None) -> pd.DataFrame: + values = safe_convert_numeric(values) + # try to convert to int, if possible + dtype = "Int64" if stats["max_scale"] == 0 else "Float64" + if dtype == "Int64": + values = values.round() + try: + values = values.astype(dtype) + except TypeError: + if dtype == "Int64": # if couldn't safely convert to int, stick to float + dtype = "Float64" + values = values.astype(dtype) + # reset index, as `values.mask` can throw errors for misaligned indices + values.reset_index(drop=True, inplace=True) + # replace extreme values with randomly sampled 5-th to 10-th largest/smallest values + min5 = _type_safe_numeric_series(stats["min5"] or [0], dtype) + max5 = _type_safe_numeric_series(stats["max5"] or [0], dtype) + values.mask( + values < min5[0], + min5.sample(n=len(values), replace=True, ignore_index=True), + inplace=True, + ) + values.mask( + values > max5[0], + max5.sample(n=len(values), replace=True, ignore_index=True), + inplace=True, + ) + return values + + +def _clip_numeric(x: pd.Series, min5: list, max5: list) -> pd.Series: + x_numeric = pd.to_numeric(x, errors="coerce") + min_arr = np.array(min5, dtype=x_numeric.dtype) + max_arr = np.array(max5, dtype=x_numeric.dtype) + n = len(x_numeric) + random_mins = np.random.choice(min_arr, size=n) + random_maxs = np.random.choice(max_arr, size=n) + clipped = np.minimum(np.maximum(x_numeric.to_numpy(), random_mins), random_maxs) + return pd.Series(clipped, index=x.index) + + +def decode_language_numeric(x: pd.Series, col_stats: dict[str, str]) -> pd.Series: + x = pd.to_numeric(x, errors="coerce") + x = x.round(col_stats["max_scale"]) + x = _clip_numeric(x, col_stats["min5"], col_stats["max5"]) + dtype = "Int64" if col_stats["max_scale"] == 0 else float + return x.astype(dtype) diff --git a/mostlyai/engine/_encoding_types/language/text.py b/mostlyai/engine/_encoding_types/language/text.py index 2e89b75..245699a 100644 --- a/mostlyai/engine/_encoding_types/language/text.py +++ 
b/mostlyai/engine/_encoding_types/language/text.py @@ -14,7 +14,7 @@ import pandas as pd -from mostlyai.engine._common import safe_convert_string +from mostlyai.engine._common import safe_convert_string, STRING def analyze_text(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict: @@ -39,3 +39,7 @@ def analyze_reduce_text(stats_list: list[dict], _: bool = True) -> dict: "nchar_max": nchar_max, } return stats + + +def decode_text(x: pd.Series, col_stats: dict[str, str]) -> pd.Series: + return x.astype(STRING) diff --git a/mostlyai/engine/_encoding_types/tabular/numeric.py b/mostlyai/engine/_encoding_types/tabular/numeric.py index 44cf58c..8edff93 100644 --- a/mostlyai/engine/_encoding_types/tabular/numeric.py +++ b/mostlyai/engine/_encoding_types/tabular/numeric.py @@ -165,7 +165,7 @@ def analyze_numeric( # do not count values, if there are too many cnt_values = None - # determine lowest/highest values by root ID, and return Top 10 + # determine lowest/highest values by root ID, and return top 11 df = pd.concat([root_keys, values], axis=1) min_values = df.groupby(root_keys.name)[values.name].min().dropna() min11 = min_values.sort_values(ascending=True).head(11).astype("float").tolist() diff --git a/mostlyai/engine/_language/encoding.py b/mostlyai/engine/_language/encoding.py index cb74ada..752624f 100644 --- a/mostlyai/engine/_language/encoding.py +++ b/mostlyai/engine/_language/encoding.py @@ -24,17 +24,36 @@ from mostlyai.engine._common import is_sequential, ProgressCallback, ProgressCallbackWrapper, TABLE_COLUMN_INFIX from mostlyai.engine._workspace import ensure_workspace_dir, Workspace, reset_dir +from mostlyai.engine._encoding_types.language.categorical import encode_language_categorical +from mostlyai.engine._encoding_types.language.numeric import encode_language_numeric +from mostlyai.engine._encoding_types.language.datetime import encode_language_datetime _LOG = logging.getLogger(__name__) -def format_df(df: pd.DataFrame, columns: list[str], is_target: bool = False) -> pd.DataFrame: - df = df[columns].copy() +def apply_encoding_types(df: pd.DataFrame, stats: dict) -> pd.DataFrame: + for col, col_stats in stats["columns"].items(): + if col_stats["encoding_type"] == "LANGUAGE_CATEGORICAL": + df[col] = encode_language_categorical(df[col], col_stats) + elif col_stats["encoding_type"] == "LANGUAGE_NUMERIC": + df[col] = encode_language_numeric(df[col], col_stats) + elif col_stats["encoding_type"] == "LANGUAGE_DATETIME": + df[col] = encode_language_datetime(df[col], col_stats) + return df + + +def drop_sequential_columns(df: pd.DataFrame) -> pd.DataFrame: # Some columns (e.g., SCP columns) may contain np.ndarray, which are not JSON serializable # We need to drop them before converting the DataFrame to JSON sequential_columns = [col for col in df.columns if is_sequential(df[col])] df = df.drop(columns=sequential_columns) - _LOG.info(f"Formatting {'target' if is_target else 'context'} columns {df.columns.tolist()} to JSON") + return df + + +def format_df(df: pd.DataFrame, stats: dict, is_target: bool = False) -> pd.DataFrame: + columns = list(stats["columns"].keys()) + df = df[columns].copy() + _LOG.info(f"Formatting {'target' if is_target else 'context'} columns {columns} to JSON") # convert date format to ISO so that it's JSON serializable for col in df.columns: if is_datetime64_any_dtype(df[col]): @@ -76,15 +95,21 @@ def row_to_json(row: pd.Series, is_target: bool = False) -> str: def encode_df( ctx_df: pd.DataFrame, - ctx_columns: list[str], + ctx_stats: dict | 
None = None, tgt_df: pd.DataFrame | None = None, - tgt_columns: list[str] | None = None, + tgt_stats: dict | None = None, ) -> pd.DataFrame: - assert (tgt_df is None) == (tgt_columns is None), "tgt_df and tgt_columns must be both None or both not None" + assert (tgt_df is None) == (tgt_stats is None), "tgt_df and tgt_stats must be both None or both not None" + if ctx_stats is None: + ctx_stats = {"columns": {}} df = pd.DataFrame() - df["ctx"] = format_df(ctx_df, columns=ctx_columns, is_target=False) - if tgt_df is not None and tgt_columns is not None: - df["tgt"] = format_df(tgt_df, columns=tgt_columns, is_target=True) + ctx_df = drop_sequential_columns(ctx_df) + ctx_df = apply_encoding_types(ctx_df, stats=ctx_stats) + df["ctx"] = format_df(ctx_df, stats=ctx_stats, is_target=False) + if tgt_df is not None and tgt_stats is not None: + tgt_df = drop_sequential_columns(tgt_df) + tgt_df = apply_encoding_types(tgt_df, stats=tgt_stats) + df["tgt"] = format_df(tgt_df, stats=tgt_stats, is_target=True) # log the bounds of n_tokens in this partition content = df["ctx"] + df["tgt"] if "tgt" in df.columns else df["ctx"] @@ -107,19 +132,16 @@ def _encode_partition( ctx_stats: dict | None = None, ) -> None: tgt_df = pd.read_parquet(tgt_partition_file) - tgt_columns = list(tgt_stats.get("columns", {}).keys()) if ctx_partition_file: ctx_df = pd.read_parquet(ctx_partition_file) - ctx_columns = list(ctx_stats.get("columns", {}).keys()) else: # create on-the-fly context ctx_df = pd.DataFrame(index=range(len(tgt_df))) - ctx_columns = [] df = encode_df( ctx_df=ctx_df, - ctx_columns=ctx_columns, + ctx_stats=ctx_stats, tgt_df=tgt_df, - tgt_columns=tgt_columns, + tgt_stats=tgt_stats, ) # shuffle and persist to disk as parquet files df = df.sample(frac=1) diff --git a/mostlyai/engine/_language/engine/hf_engine.py b/mostlyai/engine/_language/engine/hf_engine.py index 18c3c5f..88f30a2 100644 --- a/mostlyai/engine/_language/engine/hf_engine.py +++ b/mostlyai/engine/_language/engine/hf_engine.py @@ -23,8 +23,8 @@ from transformers import AutoTokenizer from mostlyai.engine._language.common import load_base_model_and_config -from mostlyai.engine._language.tokenizer_utils import tokenize_fn from mostlyai.engine._language.formatron_utils import monkey_patch_formatron +from mostlyai.engine._language.tokenizer_utils import tokenize_fn from mostlyai.engine._language.engine.base import EngineMetrics, LanguageEngine from formatron.formatter import FormatterBuilder @@ -66,11 +66,8 @@ def __init__( self.tokenizer.special_tokens_map ) self._json_enforcing_possible = is_peft_adapter or is_trained_lstm_tokenizer - - # apply all necessary monkey patches to the formatron library - if self._json_enforcing_possible: + if self.supports_json_enforcing(): monkey_patch_formatron() - self._logits_processors = None def get_default_batch_size(self) -> int: diff --git a/mostlyai/engine/_language/engine/vllm_engine.py b/mostlyai/engine/_language/engine/vllm_engine.py index 247479a..afa1268 100644 --- a/mostlyai/engine/_language/engine/vllm_engine.py +++ b/mostlyai/engine/_language/engine/vllm_engine.py @@ -26,6 +26,7 @@ from peft import PeftConfig from transformers import AutoTokenizer, AutoConfig, PreTrainedTokenizerBase +from mostlyai.engine._language.formatron_utils import monkey_patch_formatron from vllm import LLM, SamplingParams from vllm.lora.request import LoRARequest from vllm.config import _get_and_verify_max_len @@ -123,6 +124,7 @@ def __init__( add_eos_token=False, ) self._logits_processors = None + monkey_patch_formatron() def 
get_default_batch_size(self) -> int: return 192 diff --git a/mostlyai/engine/_language/formatron_utils.py b/mostlyai/engine/_language/formatron_utils.py index d6ef79f..ec6df8b 100644 --- a/mostlyai/engine/_language/formatron_utils.py +++ b/mostlyai/engine/_language/formatron_utils.py @@ -12,17 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. + import typing + import pandas as pd from formatron.schemas.pydantic import ClassSchema from json import JSONDecodeError -from pydantic import ValidationError +from pydantic import Field, SkipValidation, ValidationError from formatron.formatter import FormatterBuilder -from typing import Literal +from formatron import schemas from formatron.formats import json +from typing import Literal from pydantic import create_model from transformers import PreTrainedTokenizerBase +from mostlyai.engine._encoding_types.language.categorical import CATEGORICAL_UNKNOWN_TOKEN +from mostlyai.engine.domain import ModelEncodingType, RareCategoryReplacementMethod JSON_NULL = "null" @@ -41,41 +46,56 @@ def transform(x: str | None) -> str: return sample_seed.astype("string[pyarrow]").map(transform) -def monkey_patch_formatron(): - # alter the Grammar of formatron's json schema - FORMATRON_WHITESPACE_MAX_REPETITIONS = 10 - SPACE_NONTERMINAL = f"[ \t\n\r]{{0,{FORMATRON_WHITESPACE_MAX_REPETITIONS}}}" - - json.GRAMMAR_HEADER = rf"""integer ::= #"-?(0|[1-9]\\d*)"; - number ::= #"-?(0|[1-9]\\d*)(\\.\\d+)?([eE][+-]?\\d+)?"; - string ::= #'"([^\\\\"\u0000-\u001f]|\\\\["\\\\bfnrt/]|\\\\u[0-9A-Fa-f]{{4}})*"'; - boolean ::= "true"|"false"; - null ::= "null"; - array ::= array_begin (json_value (comma json_value)*)? array_end; - object ::= object_begin (string colon json_value (comma string colon json_value)*)? object_end; - json_value ::= number|string|boolean|null|array|object; - comma ::= #"{SPACE_NONTERMINAL},{SPACE_NONTERMINAL}"; - colon ::= #"{SPACE_NONTERMINAL}:{SPACE_NONTERMINAL}"; - object_begin ::= #" \\{{{SPACE_NONTERMINAL}"; - object_end ::= #"{SPACE_NONTERMINAL}\\}}"; - array_begin ::= #"\\[{SPACE_NONTERMINAL}"; - array_end ::= #"{SPACE_NONTERMINAL}\\]"; - """ - - def get_formatter_builders( - *, seed_df: pd.DataFrame | None = None, size: int | None = None, unseeded_fields: list[str] + *, + seed_df: pd.DataFrame | None = None, + size: int | None = None, + stats: dict, + rare_category_replacement_method: RareCategoryReplacementMethod, ) -> list[FormatterBuilder]: assert (seed_df is not None) ^ (size is not None), "exactly one of seed_df or size must be provided" formatter_builders = [] if seed_df is None: seed_df = pd.DataFrame(index=range(size)) + unseeded_fields = [c for c in list(stats["columns"].keys()) if c not in seed_df.columns.to_list()] + field_types = { + t: [col for col, col_stats in stats["columns"].items() if col_stats["encoding_type"] == t] + for t in ModelEncodingType + } + categorical_fields = field_types.get(ModelEncodingType.language_categorical, []) + numeric_fields = field_types.get(ModelEncodingType.language_numeric, []) + datetime_fields = field_types.get(ModelEncodingType.language_datetime, []) for _, seed_row in seed_df.iterrows(): formatter_builder = FormatterBuilder() model_dict = {} if not seed_row.empty: - model_dict |= {field_name: (Literal[seed_value], ...) for field_name, seed_value in seed_row.items()} - model_dict |= {field_name: (str, ...) for field_name in unseeded_fields} + model_dict |= {field_name: (Literal[seed_value], ...) 
for field_name, seed_value in seed_row.items()} # type: ignore[valid-type] + for field_name in unseeded_fields: + if field_name in categorical_fields: + categories = stats["columns"][field_name]["categories"] + if rare_category_replacement_method == RareCategoryReplacementMethod.sample and len(categories) > 1: + categories = [c for c in categories if c != CATEGORICAL_UNKNOWN_TOKEN] + model_dict[field_name] = (Literal[tuple(categories)], ...) # type: ignore[valid-type] + elif field_name in numeric_fields: + max_scale = stats["columns"][field_name]["max_scale"] + min_min5 = min(stats["columns"][field_name]["min5"]) + max_max5 = max(stats["columns"][field_name]["max5"]) + if max_scale == 0: + model_dict[field_name] = (SkipValidation[int], Field(ge=min_min5, le=max_max5)) + else: + model_dict[field_name] = ( + SkipValidation[float], + Field(ge=min_min5, le=max_max5, decimal_places=max_scale), + ) + elif field_name in datetime_fields: + model_dict[field_name] = ( + SkipValidation[str], + Field( + pattern=r"""(19\\d{2}|20\\d{2})-(0[1-9]|1[0-2])-(0[1-9]|1[0-9]|2[0-9]|3[0-1])T([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])""" + ), + ) + else: + model_dict[field_name] = (str, ...) schema = create_model("TargetModel", **model_dict, __base__=MostlyClassSchema) formatter_builder.append_str(f"{formatter_builder.json(schema, capture_name=None)}") formatter_builders.append(formatter_builder) @@ -112,3 +132,127 @@ def from_json(cls, _json: str) -> "MostlyClassSchema": f"Caught pydantic ValidationError {e}, reraising as JSONDecodeError", _json, 0 ) raise e + + +# copy formatron: direct copy from formatron +def _string_metadata(current: type, nonterminal: str): + min_length = current.metadata.get("min_length") + max_length = current.metadata.get("max_length") + pattern = current.metadata.get("pattern") + substring_of = current.metadata.get("substring_of") + if pattern: + assert not (min_length or max_length or substring_of), ( + "pattern is mutually exclusive with min_length, max_length and substring_of" + ) + if substring_of: + assert not (min_length or max_length or pattern), ( + "substring_of is mutually exclusive with min_length, max_length and pattern" + ) + repetition_map = { + (True, False): f"{{{min_length},}}", + (False, True): f"{{0,{max_length}}}", + (True, True): f"{{{min_length},{max_length}}}", + } + repetition = repetition_map.get((min_length is not None, max_length is not None)) + if repetition is not None: + return ( + rf"""{nonterminal} ::= #'"([^\\\\"\u0000-\u001f]|\\\\["\\\\bfnrt/]|\\\\u[0-9A-Fa-f]{{4}}){repetition}"'; +""", + [], + ) + if pattern is not None: + pattern = pattern.replace("'", "\\'") + return f"""{nonterminal} ::= #'"{pattern}"';\n""", [] + if substring_of is not None: + return f"""{nonterminal} ::= '"' #substrs{repr(substring_of)} '"';\n""", [] + + +# completely altered vs formatron +def _number_metadata(current: type, nonterminal: str): + # For now only constrains number of digits and whether it is negative + gt = current.metadata.get("gt") + ge = current.metadata.get("ge") + lt = current.metadata.get("lt") + le = current.metadata.get("le") + if lt is not None or gt is not None: + raise NotImplementedError("gt and lt are not supported for number metadata") + if le < ge: + raise ValueError("le must be greater than or equal to ge") + + pattern_parts = [] + if issubclass(current.type, float): + le, le_frac = str(le).split(".") + ge, ge_frac = str(ge).split(".") + le, le_frac = int(le), int(le_frac) + ge, ge_frac = int(ge), int(ge_frac) + decimal_places = 
current.metadata.get("decimal_places") + + if ge is not None and le is not None: + if ge < 0 and le < 0: + pattern_parts.append("-") + min_num = abs(le) + max_num = abs(ge) + max_digits = len(str(max_num)) + min_digits = len(str(min_num)) + pattern_parts.append(rf"([1-9][0-9]{{{min_digits - 1},{max_digits - 1}}})") + elif ge > 0: + min_num = ge + max_num = le + max_digits = len(str(max_num)) + min_digits = len(str(min_num)) + pattern_parts.append(rf"([1-9][0-9]{{{min_digits - 1},{max_digits - 1}}})") + else: + if ge < 0: + pattern_parts.append("-?") + max_digits = max(len(str(abs(ge))), len(str(abs(le)))) + pattern_parts.append(rf"(0|[1-9][0-9]{{0,{max_digits - 1}}})") + + if issubclass(current.type, float): + pattern_parts.append(rf"(\\.[0-9]{{0,{decimal_places}}})?") + + pattern = "".join(pattern_parts) + return f"""{nonterminal} ::= #"{pattern}";\n""", [] + + +# copy formatron: removed sequence metadata since unnecessary and altered number_metadata to use ours +def _metadata(current: type, nonterminal: str): + if isinstance(current, schemas.schema.TypeWithMetadata): + original = typing.get_origin(current.type) + if original is None: + original = current.type + if not current.metadata: + return "", [(current.type, nonterminal)] + if isinstance(current.type, type) and issubclass(current.type, str): + return _string_metadata(current, nonterminal) + elif isinstance(current.type, type) and issubclass(current.type, (int, float)): + return _number_metadata(current, nonterminal) + return None + + +def monkey_patch_formatron(): + FORMATRON_WHITESPACE_MAX_REPETITIONS = 10 + SPACE_NONTERMINAL = f"[ \t\n\r]{{0,{FORMATRON_WHITESPACE_MAX_REPETITIONS}}}" + + # Copy from formatron, altered to have limited whitespace repetitions and datetime format + json.GRAMMAR_HEADER = rf"""integer ::= #"-?(0|[1-9]\\d*)"; + number ::= #"-?(0|[1-9]\\d*)(\\.\\d+)?([eE][+-]?\\d+)?"; + string ::= #'"([^\\\\"\u0000-\u001f]|\\\\["\\\\bfnrt/]|\\\\u[0-9A-Fa-f]{{4}})*"'; + boolean ::= "true"|"false"; + null ::= "null"; + array ::= array_begin (json_value (comma json_value)*)? array_end; + object ::= object_begin (string colon json_value (comma string colon json_value)*)? 
object_end; + json_value ::= number|string|boolean|null|array|object; + comma ::= #"{SPACE_NONTERMINAL},{SPACE_NONTERMINAL}"; + colon ::= #"{SPACE_NONTERMINAL}:{SPACE_NONTERMINAL}"; + object_begin ::= #" \\{{{SPACE_NONTERMINAL}"; + object_end ::= #"{SPACE_NONTERMINAL}\\}}"; + array_begin ::= #"\\[{SPACE_NONTERMINAL}"; + array_end ::= #"{SPACE_NONTERMINAL}\\]"; + """ + + def alter_type_to_nonterminals_metadata_inplace(type_to_nonterminals: list[typing.Callable]): + metadata_idx = [idx for idx, fn in enumerate(type_to_nonterminals) if fn.__name__ == "metadata"] + if len(metadata_idx) == 1: + type_to_nonterminals[metadata_idx[0]] = _metadata + + alter_type_to_nonterminals_metadata_inplace(json._type_to_nonterminals) diff --git a/mostlyai/engine/_language/generation.py b/mostlyai/engine/_language/generation.py index 6c57cd5..e712af1 100644 --- a/mostlyai/engine/_language/generation.py +++ b/mostlyai/engine/_language/generation.py @@ -32,10 +32,13 @@ from mostlyai.engine._common import ( persist_data_part, FixedSizeSampleBuffer, - STRING, ProgressCallback, ProgressCallbackWrapper, ) +from mostlyai.engine._encoding_types.language.categorical import decode_language_categorical +from mostlyai.engine._encoding_types.language.datetime import decode_language_datetime +from mostlyai.engine._encoding_types.language.numeric import decode_language_numeric +from mostlyai.engine._encoding_types.language.text import decode_text from mostlyai.engine._language.common import estimate_max_tokens, MAX_LENGTH from mostlyai.engine._language.encoding import encode_df from mostlyai.engine._workspace import ensure_workspace_dir, Workspace, reset_dir @@ -44,6 +47,7 @@ prepare_seed_for_formatron, get_vocab_processors, ) +from mostlyai.engine.domain import ModelEncodingType, RareCategoryReplacementMethod INVALID_VALUE = "_INVALID_" # when JSON parsing fails, the values of target columns will be set to this DUMMY_CONTEXT_KEY = "__dummy_context_key" @@ -53,7 +57,7 @@ def decode_buffered_samples( buffer: FixedSizeSampleBuffer, tokenizer: PreTrainedTokenizerBase, - tgt_text_columns: list[str], + tgt_stats: dict[str, str], tgt_context_key: str, max_new_tokens: int, ): @@ -78,6 +82,7 @@ def parse_json(x, columns: list[str]): num_samples_max_length_limit += sum(1 for tokens in num_tokens_by_row if tokens >= max_new_tokens) except AttributeError: num_samples_max_length_limit = float("-inf") + outputs_text = tokenizer.batch_decode(outputs_ids, skip_special_tokens=True) output_texts.extend(outputs_text) ctx_keys.append(keys_df) @@ -87,8 +92,8 @@ def parse_json(x, columns: list[str]): tgt_seed = pd.concat(tgt_seed, axis=0).reset_index(drop=True) # The model works with un-prefixed column names, but we need to recover prefixed column names for the final output tgt_data = pd.DataFrame( - [parse_json(text, tgt_text_columns) for text in output_texts], - columns=tgt_text_columns, + [parse_json(text, tgt_stats["columns"].keys()) for text in output_texts], + columns=tgt_stats["columns"].keys(), index=ctx_keys.index, dtype="string", ) @@ -98,14 +103,25 @@ def parse_json(x, columns: list[str]): ) # overwrite generated columns with the seeded values tgt_data.update(tgt_seed) - # ensure STRING type - tgt_data = tgt_data.astype(STRING) + # prepend the context keys to the data (if not dummy context) if ctx_keys.name != DUMMY_CONTEXT_KEY: tgt_data = pd.concat([ctx_keys, tgt_data], axis=1) - invalid_percentage = ((tgt_data[tgt_text_columns] == INVALID_VALUE).sum() / len(tgt_data) * 100.0).map( + invalid_percentage = 
((tgt_data[tgt_stats["columns"].keys()] == INVALID_VALUE).sum() / len(tgt_data) * 100.0).map( "{:.2f}%".format ) + + for col in tgt_stats["columns"].keys(): + col_stats = tgt_stats["columns"][col] + if col_stats["encoding_type"] == ModelEncodingType.language_numeric: + tgt_data[col] = decode_language_numeric(tgt_data[col], col_stats) + elif col_stats["encoding_type"] == ModelEncodingType.language_datetime: + tgt_data[col] = decode_language_datetime(tgt_data[col], col_stats) + elif col_stats["encoding_type"] == ModelEncodingType.language_categorical: + tgt_data[col] = decode_language_categorical(tgt_data[col], col_stats) + else: + tgt_data[col] = decode_text(tgt_data[col], col_stats) + _LOG.info(f"percentage of invalid values: {invalid_percentage.to_dict()}") _LOG.info(f"decoded {tgt_data.shape} from {len(buffer.buffer)} batches in {time.time() - t0:.2f}s") return tgt_data @@ -119,6 +135,7 @@ def generate( batch_size: int | None = None, sampling_temperature: float = 1.0, sampling_top_p: float = 1.0, + rare_category_replacement_method: RareCategoryReplacementMethod | str = RareCategoryReplacementMethod.constant, device: torch.device | str | None = None, workspace_dir: str | Path = "engine-ws", update_progress: ProgressCallback | None = None, @@ -162,7 +179,6 @@ def tqdm_disabled(): if has_context: ctx_stats = workspace.ctx_stats.read() - ctx_columns = list(ctx_stats["columns"].keys()) ctx_primary_key = ctx_stats["keys"].get("primary_key") # ensure ctx_data exists @@ -187,11 +203,11 @@ def tqdm_disabled(): sample_size = len(ctx_data) _LOG.info(f"{sample_size=}") else: + ctx_stats = None # create on-the-fly context if sample_size is None: trn_sample_size = tgt_stats["no_of_training_records"] + tgt_stats["no_of_validation_records"] sample_size = trn_sample_size if sample_size is None else sample_size - ctx_columns = [] ctx_primary_key = tgt_context_key = DUMMY_CONTEXT_KEY ctx_data = pd.DataFrame({ctx_primary_key: range(sample_size)}) @@ -213,7 +229,7 @@ def tqdm_disabled(): return # encode context data - encoded_ctx_data = encode_df(ctx_df=ctx_data, ctx_columns=ctx_columns) + encoded_ctx_data = encode_df(ctx_df=ctx_data, ctx_stats=ctx_stats) # estimate max new tokens based on char length of original data; consider JSON overhead max_new_tokens = estimate_max_tokens(tgt_stats) @@ -247,7 +263,6 @@ def tqdm_disabled(): # prepare seed data for clean consumption by formatron seed_data = prepare_seed_for_formatron(seed_data, engine.tokenizer) seeded_tgt_columns = seed_data.columns.to_list() - unseeded_tgt_columns = [c for c in tgt_text_columns if c not in seeded_tgt_columns] total_tokenize_fn_time = 0 total_logits_processor_build_time = 0 @@ -259,7 +274,9 @@ def tqdm_disabled(): if enforce_json_output and len(seeded_tgt_columns) == 0: t0 = time.time() - formatter_builders = get_formatter_builders(size=batch_size, unseeded_fields=unseeded_tgt_columns) + formatter_builders = get_formatter_builders( + size=batch_size, stats=tgt_stats, rare_category_replacement_method=rare_category_replacement_method + ) engine.initialize_logits_processors(formatter_builders, formatron_vocab_processors) total_logits_processor_build_time += time.time() - t0 @@ -279,7 +296,8 @@ def tqdm_disabled(): # some columns are seeded, so we need to create a new logits processor for each batch formatter_builders = get_formatter_builders( seed_df=sample_seed_batch, - unseeded_fields=unseeded_tgt_columns, + stats=tgt_stats, + rare_category_replacement_method=rare_category_replacement_method, ) 
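Once a buffer of generated samples is full, decode_buffered_samples now dispatches per column to the new decoders instead of returning plain strings. A rough sketch of that per-column step, using hypothetical column stats that contain only the keys the decoders read (not part of the change set):

import pandas as pd

from mostlyai.engine._encoding_types.language.categorical import decode_language_categorical
from mostlyai.engine._encoding_types.language.numeric import decode_language_numeric

# hypothetical stats for a LANGUAGE_NUMERIC column
age_stats = {"max_scale": 0, "min5": [18, 19, 20, 21, 22], "max5": [60, 59, 58, 57, 56]}
raw_age = pd.Series(["23", "7", "_INVALID_"], dtype="string")
decoded_age = decode_language_numeric(raw_age, age_stats)
# -> Int64 series; unparsable values become <NA>, and out-of-range values are
#    clipped to a randomly drawn bound from min5/max5 (here 7 ends up in 18..22)

# hypothetical stats for a LANGUAGE_CATEGORICAL column
gender_stats = {"categories": ["_RARE_", "f", "m"]}
raw_gender = pd.Series(["f", "unicorn"], dtype="string")
decoded_gender = decode_language_categorical(raw_gender, gender_stats)
# -> "unicorn" is not a known category and decodes to None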
engine.initialize_logits_processors(formatter_builders, formatron_vocab_processors) total_logits_processor_build_time += time.time() - t0 @@ -295,7 +313,7 @@ def tqdm_disabled(): buffer.add((outputs, ctx_keys, sample_seed_batch)) if buffer.is_full(): decoded_data = decode_buffered_samples( - buffer, engine.tokenizer, tgt_text_columns, tgt_context_key, max_new_tokens + buffer, engine.tokenizer, tgt_stats, tgt_context_key, max_new_tokens ) persist_data_part( decoded_data, @@ -307,9 +325,7 @@ def tqdm_disabled(): samples_processed += len(ctx_batch) if not buffer.is_empty(): - decoded_data = decode_buffered_samples( - buffer, engine.tokenizer, tgt_text_columns, tgt_context_key, max_new_tokens - ) + decoded_data = decode_buffered_samples(buffer, engine.tokenizer, tgt_stats, tgt_context_key, max_new_tokens) persist_data_part( decoded_data, output_path, diff --git a/mostlyai/engine/_language/tokenizer_utils.py b/mostlyai/engine/_language/tokenizer_utils.py index ff2cba0..64c13e5 100644 --- a/mostlyai/engine/_language/tokenizer_utils.py +++ b/mostlyai/engine/_language/tokenizer_utils.py @@ -19,13 +19,19 @@ from transformers import DataCollatorForLanguageModeling, BatchEncoding, PreTrainedTokenizerFast, LlamaTokenizerFast from transformers.data.data_collator import pad_without_fast_tokenizer_warning, _torch_collate_batch +from mostlyai.engine.domain import ModelEncodingType + ################# ### TOKENIZER ### ################# -def train_tokenizer(training_iterator: Iterator | list | None = None, tokenizer_kwargs=None): +def train_tokenizer( + training_iterator: Iterator | list | None = None, + tokenizer_kwargs: dict[str, Any] | None = None, + tgt_stats: dict[str, Any] | None = None, +): if tokenizer_kwargs is None: tokenizer_kwargs = {} from tokenizers import Tokenizer, decoders @@ -46,10 +52,26 @@ def train_tokenizer(training_iterator: Iterator | list | None = None, tokenizer_ MIN_FREQ_MERGE = 20 VOCAB_SIZE = 5000 + # add initial alphabet for numeric and datetime columns if needed + has_numeric_columns = any( + col_stats["encoding_type"] == ModelEncodingType.language_numeric for col_stats in tgt_stats["columns"].values() + ) + has_datetime_columns = any( + col_stats["encoding_type"] == ModelEncodingType.language_datetime for col_stats in tgt_stats["columns"].values() + ) + initial_alphabet = set() + if has_numeric_columns: + # FIXME: maybe the set can be more fine-grained based on max_scale in stats + initial_alphabet |= {str(i) for i in range(10)} | {".", "-", "+", "e", "E"} + if has_datetime_columns: + initial_alphabet |= {str(i) for i in range(10)} | {".", "-", ":", "T", "Z"} + initial_alphabet = list(initial_alphabet) + # Builds a BPE raw_tokenizer, and optionally trains it based on provided text training_iterator = training_iterator or [] # allow easy training skip raw_tokenizer = Tokenizer(BPE(unk_token=special_tokens["unk_token"])) trainer = BpeTrainer( + initial_alphabet=initial_alphabet, special_tokens=SPECIAL_TOKENS, min_frequency=MIN_FREQ_MERGE, vocab_size=VOCAB_SIZE, diff --git a/mostlyai/engine/_language/training.py b/mostlyai/engine/_language/training.py index 4a21da8..15eebc4 100644 --- a/mostlyai/engine/_language/training.py +++ b/mostlyai/engine/_language/training.py @@ -38,7 +38,6 @@ from torch.utils.data import DataLoader from mostlyai.engine._common import ( - STRING, ProgressCallback, ProgressCallbackWrapper, TABLE_COLUMN_INFIX, @@ -272,7 +271,7 @@ def train( raw_dataset = load_dataset("parquet", data_files=data_files) def shuffle_tgt_columns(x): - x_tgt = 
pd.DataFrame([json.loads(x.pop("tgt"))], dtype=STRING) # convert to DataFrame + x_tgt = pd.DataFrame([json.loads(x.pop("tgt"))]) # convert to DataFrame x_tgt = x_tgt.sample(frac=1, axis=1) # shuffle columns x_tgt = row_to_json( x_tgt.add_prefix("tgt" + TABLE_COLUMN_INFIX).squeeze(axis=0), is_target=True @@ -352,7 +351,7 @@ def concat_prompt_and_response(x): for i in range(0, len(content_dataset["train"]), 1_000) ) # train a custom tokenizer and convert it to a LlamaTokenizerFast object - tokenizer = train_tokenizer(tokenizer_train_iter, tokenizer_kwargs=tokenizer_args) + tokenizer = train_tokenizer(tokenizer_train_iter, tokenizer_kwargs=tokenizer_args, tgt_stats=tgt_stats) model_config = LSTMFromScratchConfig(vocab_size=len(tokenizer), with_dp=with_dp) model = LSTMFromScratchLMHeadModel(model_config).to(device) else: diff --git a/mostlyai/engine/analysis.py b/mostlyai/engine/analysis.py index 1094c61..2a27b77 100644 --- a/mostlyai/engine/analysis.py +++ b/mostlyai/engine/analysis.py @@ -41,6 +41,11 @@ ProgressCallback, ProgressCallbackWrapper, ) +from mostlyai.engine._encoding_types.language.datetime import ( + analyze_reduce_language_datetime, + analyze_language_datetime, +) +from mostlyai.engine._encoding_types.language.numeric import analyze_language_numeric, analyze_reduce_language_numeric from mostlyai.engine._encoding_types.tabular.categorical import ( analyze_categorical, analyze_reduce_categorical, @@ -66,6 +71,10 @@ analyze_text, analyze_reduce_text, ) +from mostlyai.engine._encoding_types.language.categorical import ( + analyze_language_categorical, + analyze_reduce_language_categorical, +) from mostlyai.engine.domain import ModelEncodingType from mostlyai.engine._workspace import ( @@ -84,6 +93,9 @@ ModelEncodingType.tabular_numeric_binned, ModelEncodingType.tabular_datetime, ModelEncodingType.tabular_datetime_relative, + ModelEncodingType.language_categorical, + ModelEncodingType.language_numeric, + ModelEncodingType.language_datetime, ) @@ -313,22 +325,7 @@ def _analyze_reduce( column: column_stats.get("encoding_type") for column, column_stats in stats_list[0]["columns"].items() } - # build mapping of original column name to ARGN table and column identifiers - def get_table(qualified_column_name: str) -> str: - # column names are assumed to be :: - return qualified_column_name.split(TABLE_COLUMN_INFIX)[0] - - def get_unique_tables(qualified_column_names: Iterable[str]) -> list[str]: - duplicated_tables = [get_table(c) for c in qualified_column_names] - return list(dict.fromkeys(duplicated_tables)) - - unique_tables = get_unique_tables(encoding_types.keys()) - argn_identifiers: dict[str, tuple[str, str]] = { - c: (f"t{unique_tables.index(get_table(qualified_column_name=c))}", f"c{idx}") - for idx, c in enumerate(encoding_types.keys()) - } - - for i, column in enumerate(encoding_types.keys()): + for column in encoding_types: encoding_type = encoding_types[column] column_stats_list = [item["columns"][column] for item in stats_list] column_stats_list = [ @@ -379,6 +376,21 @@ def get_unique_tables(qualified_column_names: Iterable[str]) -> list[str]: ) elif encoding_type == ModelEncodingType.language_text: stats_col = analyze_reduce_text(stats_list=column_stats_list) + elif encoding_type == ModelEncodingType.language_categorical: + stats_col = analyze_reduce_text(stats_list=column_stats_list) | analyze_reduce_language_categorical( + stats_list=column_stats_list, + value_protection=value_protection, + ) + elif encoding_type == ModelEncodingType.language_numeric: + stats_col = 
analyze_reduce_text(stats_list=column_stats_list) | analyze_reduce_language_numeric( + stats_list=column_stats_list, + value_protection=value_protection, + ) + elif encoding_type == ModelEncodingType.language_datetime: + stats_col = analyze_reduce_text(stats_list=column_stats_list) | analyze_reduce_language_datetime( + stats_list=column_stats_list, + value_protection=value_protection, + ) else: raise RuntimeError(f"unknown encoding type {encoding_type}") @@ -388,29 +400,53 @@ def get_unique_tables(qualified_column_names: Iterable[str]) -> list[str]: if encoding_type in _VALUE_PROTECTION_ENCODING_TYPES: stats_col = {"value_protection": value_protection} | stats_col - # select model pipeline to process given column - def get_argn_processor(mode, is_flat) -> str: - if mode == "tgt": - return TGT - else: # mode == "ctx" - return CTXFLT if is_flat else CTXSEQ - - is_flat = "seq_len" not in column_stats_list[0] - stats_col[ARGN_PROCESSOR] = get_argn_processor(mode, is_flat) - ( - stats_col[ARGN_TABLE], - stats_col[ARGN_COLUMN], - ) = argn_identifiers[column] - - if not is_flat: + is_flat_column = "seq_len" not in column_stats_list[0] + if not is_flat_column: stats_col["seq_len"] = _analyze_reduce_seq_len([column_stats_list[0]["seq_len"]]) - if encoding_type == ModelEncodingType.language_text: - _LOG.info( - f"analyzed column `{column}`: {stats_col['encoding_type']} nchar_max={stats_col['nchar_max']} nchar_avg={stats_col['nchar_avg']}" + is_language_column = encoding_type in ( + ModelEncodingType.language_text, + ModelEncodingType.language_categorical, + ModelEncodingType.language_numeric, + ModelEncodingType.language_datetime, + ) + + if not is_language_column: + # build mapping of original column name to ARGN table and column identifiers + def get_table(qualified_column_name: str) -> str: + # column names are assumed to be
:: + return qualified_column_name.split(TABLE_COLUMN_INFIX)[0] + + def get_unique_tables(qualified_column_names: Iterable[str]) -> list[str]: + duplicated_tables = [get_table(c) for c in qualified_column_names] + return list(dict.fromkeys(duplicated_tables)) + + unique_tables = get_unique_tables(encoding_types.keys()) + argn_identifiers: dict[str, tuple[str, str]] = { + c: (f"t{unique_tables.index(get_table(qualified_column_name=c))}", f"c{idx}") + for idx, c in enumerate(encoding_types.keys()) + } + + def get_argn_processor(mode, is_flat) -> str: + if mode == "tgt": + return TGT + else: # mode == "ctx" + return CTXFLT if is_flat else CTXSEQ + + stats_col[ARGN_PROCESSOR] = get_argn_processor(mode, is_flat="seq_len" not in column_stats_list[0]) + ( + stats_col[ARGN_TABLE], + stats_col[ARGN_COLUMN], + ) = argn_identifiers[column] + + _LOG.info( + f"analyzed column `{column}`: {stats_col['encoding_type']} " + + ( + f"nchar_max={stats_col['nchar_max']} nchar_avg={stats_col['nchar_avg']}" + if is_language_column + else f"{stats_col['cardinalities']}" ) - else: - _LOG.info(f"analyzed column `{column}`: {stats_col['encoding_type']} {stats_col['cardinalities']}") + ) stats["columns"][column] = stats_col if mode == "ctx": @@ -513,6 +549,18 @@ def _analyze_flat_col( stats = analyze_latlong(values, root_keys, context_keys) elif encoding_type == ModelEncodingType.language_text: stats = analyze_text(values, root_keys, context_keys) + elif encoding_type == ModelEncodingType.language_categorical: + stats = analyze_text(values, root_keys, context_keys) | analyze_language_categorical( + values, root_keys, context_keys + ) + elif encoding_type == ModelEncodingType.language_numeric: + stats = analyze_text(values, root_keys, context_keys) | analyze_language_numeric( + values, root_keys, context_keys + ) + elif encoding_type == ModelEncodingType.language_datetime: + stats = analyze_text(values, root_keys, context_keys) | analyze_language_datetime( + values, root_keys, context_keys + ) else: raise RuntimeError(f"unknown encoding type: `{encoding_type}` for `{values.name}`") return stats diff --git a/mostlyai/engine/domain.py b/mostlyai/engine/domain.py index 6afab60..ac71625 100644 --- a/mostlyai/engine/domain.py +++ b/mostlyai/engine/domain.py @@ -47,7 +47,10 @@ class ModelEncodingType(str, Enum): - `TABULAR_DATETIME`: Model samples each part of a datetime value. - `TABULAR_DATETIME_RELATIVE`: Model samples the relative difference between datetimes within a sequence. - `TABULAR_LAT_LONG`: Model samples a latitude-longitude column. The format is "latitude,longitude". - - `LANGUAGE_TEXT`: Model will train a distinct LANGUAGE model for this column, to then generate free text. + - `LANGUAGE_TEXT`: Model will sample free text, using a LANGUAGE model. + - `LANGUAGE_CATEGORICAL`: Model samples from existing (non-rare) categories, using a LANGUAGE model. + - `LANGUAGE_NUMERIC`: Model samples from the valid numeric value range, using a LANGUAGE model. + - `LANGUAGE_DATETIME`: Model samples from the valid datetime value range, using a LANGUAGE model. 
""" auto = "AUTO" @@ -61,6 +64,9 @@ class ModelEncodingType(str, Enum): tabular_datetime_relative = "TABULAR_DATETIME_RELATIVE" tabular_lat_long = "TABULAR_LAT_LONG" language_text = "LANGUAGE_TEXT" + language_categorical = "LANGUAGE_CATEGORICAL" + language_numeric = "LANGUAGE_NUMERIC" + language_datetime = "LANGUAGE_DATETIME" class ModelStateStrategy(str, Enum): diff --git a/mostlyai/engine/generation.py b/mostlyai/engine/generation.py index efc8639..e2a1b92 100644 --- a/mostlyai/engine/generation.py +++ b/mostlyai/engine/generation.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect from pathlib import Path import pandas as pd @@ -36,7 +35,7 @@ def generate( sampling_temperature: float = 1.0, sampling_top_p: float = 1.0, device: str | None = None, - rare_category_replacement_method: RareCategoryReplacementMethod | str | None = None, + rare_category_replacement_method: RareCategoryReplacementMethod | str = RareCategoryReplacementMethod.constant, rebalancing: RebalancingConfig | dict | None = None, imputation: ImputationConfig | dict | None = None, fairness: FairnessConfig | dict | None = None, @@ -76,9 +75,7 @@ def generate( batch_size=batch_size, sampling_temperature=sampling_temperature, sampling_top_p=sampling_top_p, - rare_category_replacement_method=inspect.signature(generate_tabular) - .parameters["rare_category_replacement_method"] - .default, + rare_category_replacement_method=rare_category_replacement_method, rebalancing=rebalancing, imputation=imputation, fairness=fairness, @@ -95,8 +92,6 @@ def generate( raise ValueError("fairness is not supported for language models") if rebalancing is not None: raise ValueError("rebalancing is not supported for language models") - if rare_category_replacement_method is not None: - raise ValueError("rare_category_replacement_method is not supported for language models") return generate_language( ctx_data=ctx_data, seed_data=seed_data, @@ -104,6 +99,7 @@ def generate( batch_size=batch_size, sampling_temperature=sampling_temperature, sampling_top_p=sampling_top_p, + rare_category_replacement_method=rare_category_replacement_method, device=device, workspace_dir=workspace_dir, update_progress=update_progress, diff --git a/tests/end_to_end/test_language.py b/tests/end_to_end/test_language.py index 38d3ea3..b5a3add 100644 --- a/tests/end_to_end/test_language.py +++ b/tests/end_to_end/test_language.py @@ -27,12 +27,17 @@ from mostlyai.engine._language.encoding import encode from mostlyai.engine.analysis import analyze from mostlyai.engine._common import TEMPORARY_PRIMARY_KEY +from mostlyai.engine._encoding_types.language.categorical import CATEGORICAL_UNKNOWN_TOKEN from mostlyai.engine._language.lstm import LSTMFromScratchConfig from mostlyai.engine._language.tokenizer_utils import MostlyDataCollatorForLanguageModeling from mostlyai.engine._language.training import train -from mostlyai.engine.domain import ModelEncodingType, ModelStateStrategy, DifferentialPrivacyConfig - -from mostlyai.engine._language.formatron_utils import get_formatter_builders +from mostlyai.engine.domain import ( + ModelEncodingType, + ModelStateStrategy, + DifferentialPrivacyConfig, + RareCategoryReplacementMethod, +) +from mostlyai.engine._language.formatron_utils import get_formatter_builders, _number_metadata from formatron.integrations.transformers import create_formatter_logits_processor_list @@ -256,7 +261,9 @@ def test_conditional_generation(tmp_path_factory): def test_formatter(): 
lone_leading_surrogate_issue = '{"E0": "[b]\\ud83c\\udc00\\ud83d\\ud8bc}{"}' unexpected_end_of_hex_escape_issue = '{"E0": "』』』\u200f』 avex\\ud8dd"}' - formatter_builders = get_formatter_builders(size=1, unseeded_fields=["some_field"]) + formatter_builders = get_formatter_builders( + size=1, stats={"columns": {}}, rare_category_replacement_method=RareCategoryReplacementMethod.constant + ) tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M", legacy=True) logits_processor = create_formatter_logits_processor_list(tokenizer, formatter_builders) formatter = logits_processor[0]._formatters[0] @@ -410,3 +417,136 @@ def test_special_character_column_name(tmp_path_factory): syn_data = pd.read_parquet(workspace_dir / "SyntheticData") assert len(syn_data) == 2 assert set(syn_data.columns) == set([TEMPORARY_PRIMARY_KEY] + list(tgt_encoding_types.keys())) + + +@pytest.fixture(scope="session") +def encoded_numeric_categorical_datetime_dataset(tmp_path_factory): + workspace_dir = tmp_path_factory.mktemp("ws") + no_of_records = 40 + data = pd.DataFrame( + { + "gender": ["m", "f", "x", pd.NA] * int(no_of_records / 4), + "age": [20, 30, 40, 50] * int(no_of_records / 4), + "date": [ + pd.Timestamp("2020-01-01"), + pd.Timestamp("2020-01-02"), + pd.Timestamp("2023-01-03"), + pd.Timestamp("2025-01-04"), + ] + * int(no_of_records / 4), + } + ) + rare_df = pd.DataFrame( + { + "gender": [f"rare{i + 1}" for i in range(20)], + "age": list(range(10, 20)) + list(range(51, 61)), + "date": ( + [pd.Timestamp("2019-01-01") + pd.Timedelta(days=i) for i in range(10)] + + [pd.Timestamp("2026-01-01") + pd.Timedelta(days=i) for i in range(10)] + ), + } + ) + data = pd.concat([data, rare_df], ignore_index=True) + tgt_encoding_types = { + "age": ModelEncodingType.language_numeric.value, + "gender": ModelEncodingType.language_categorical.value, + "date": ModelEncodingType.language_datetime.value, + } + split( + tgt_data=data, + workspace_dir=workspace_dir, + model_type="LANGUAGE", + tgt_encoding_types=tgt_encoding_types, + ) + analyze(workspace_dir=workspace_dir) + encode(workspace_dir=workspace_dir) + return workspace_dir + + +@pytest.mark.parametrize( + ("model_name"), + [ + LSTMFromScratchConfig.model_id, + "amd/AMD-Llama-135m", + "openai-community/gpt2", # TEMP, better model than AMD + ], +) +def test_categorical_numeric_datetime(encoded_numeric_categorical_datetime_dataset, model_name): + workspace_dir = encoded_numeric_categorical_datetime_dataset + train(workspace_dir=workspace_dir, model=model_name) + generate( + workspace_dir=workspace_dir, + sample_size=40, + rare_category_replacement_method=RareCategoryReplacementMethod.sample, + ) + + syn_data_path = workspace_dir / "SyntheticData" + syn = pd.read_parquet(syn_data_path) + assert len(syn) == 40 + assert set(syn.columns) == {"age", "gender", "date"} + + assert syn["age"].dtype == "Int64" + # test extreme value protection + assert syn["age"].min() >= 15 + assert syn["age"].max() <= 55 + + assert syn["gender"].dtype == "string" + # test rare category protection + assert "rare" not in syn["gender"].values + assert CATEGORICAL_UNKNOWN_TOKEN not in syn["gender"].values + assert syn["gender"].nunique(dropna=False) <= 4 + + assert syn["date"].dtype == "datetime64[ns]" + # test extreme value protection + dates = syn["date"].dropna() + if not dates.empty: + assert dates.min() >= pd.Timestamp("2019-01-06") + assert dates.max() <= pd.Timestamp("2026-01-05") + + +def test_number_metadata(): + class TypeWithMetadata: + def __init__(self, type, metadata): + self.type 
+            self.metadata = metadata
+
+    # test positive integer range
+    number_type = TypeWithMetadata(int, {"ge": 10, "le": 450})
+    pattern, deps = _number_metadata(number_type, "test_number")
+
+    assert deps == []
+    # should match 2-3 digit numbers between 10-999
+    assert 'test_number ::= #"([1-9][0-9]{1,2})";\n' in pattern
+
+    # test negative integer range
+    number_type = TypeWithMetadata(int, {"ge": -269, "le": -10})
+    pattern, deps = _number_metadata(number_type, "test_number")
+
+    # should match negative 2-3 digit numbers
+    assert 'test_number ::= #"-([1-9][0-9]{1,2})";\n' in pattern
+
+    # test range including both negative and positive
+    number_type = TypeWithMetadata(int, {"ge": -10, "le": 100})
+    pattern, deps = _number_metadata(number_type, "test_number")
+
+    # should allow an optional negative sign and either 0 or up to 3 digits
+    assert 'test_number ::= #"-?(0|[1-9][0-9]{0,2})";\n' in pattern
+
+    # test float with decimal places
+    number_type = TypeWithMetadata(float, {"ge": 0.0, "le": 100.0, "decimal_places": 2})
+    pattern, deps = _number_metadata(number_type, "test_number")
+
+    # should match numbers with an optional decimal part
+    assert r'test_number ::= #"(0|[1-9][0-9]{0,2})(\\.[0-9]{0,2})?";' + "\n" in pattern
+
+    # test invalid range where le < ge
+    number_type = TypeWithMetadata(int, {"ge": 100, "le": 10})
+
+    with pytest.raises(ValueError, match="le must be greater than or equal to ge"):
+        _number_metadata(number_type, "test_number")
+
+    # test unsupported gt/lt constraints
+    number_type = TypeWithMetadata(int, {"gt": 10, "lt": 100})
+
+    with pytest.raises(NotImplementedError, match="gt and lt are not supported for number metadata"):
+        _number_metadata(number_type, "test_number")
diff --git a/tests/unit/encoding_types/language/__init__.py b/tests/unit/encoding_types/language/__init__.py
new file mode 100644
index 0000000..a18e33e
--- /dev/null
+++ b/tests/unit/encoding_types/language/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 MOSTLY AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/encoding_types/language/test_categorical.py b/tests/unit/encoding_types/language/test_categorical.py
new file mode 100644
index 0000000..3721436
--- /dev/null
+++ b/tests/unit/encoding_types/language/test_categorical.py
@@ -0,0 +1,90 @@
+# Copyright 2025 MOSTLY AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pandas as pd
+import pytest
+import numpy as np
+
+from mostlyai.engine._encoding_types.language.categorical import (
+    CATEGORICAL_UNKNOWN_TOKEN,
+    decode_language_categorical,
+    analyze_language_categorical,
+    analyze_reduce_language_categorical,
+    encode_language_categorical,
+)
+
+
+class TestLanguageCategoricalAnalyze:
+    def test_3_frequent_and_1_rare_values(self):
+        values = pd.Series(np.repeat(["secret", "male", "female", pd.NA], 100), name="gender")
+        ids = pd.Series(
+            np.concatenate([np.repeat(0, 100), range(100), range(100, 200), range(200, 300)]),
+            name="subject_id",
+        )
+        stats = analyze_language_categorical(values, ids)
+        assert stats == {
+            "cnt_values": {"female": 100, "male": 100, "secret": 1},
+            "has_nan": True,
+        }
+
+
+class TestLanguageCategoricalAnalyzeReduce:
+    @pytest.fixture
+    def stats_list(self):
+        stats1 = {
+            "cnt_values": {"secret1": 1, "male": 100},
+            "has_nan": True,
+        }
+        stats2 = {
+            "cnt_values": {"secret2": 1, "male": 100, "female": 100},
+            "has_nan": False,
+        }
+        return stats1, stats2
+
+    def test_with_value_protection(self, stats_list):
+        stats1, stats2 = stats_list
+        stats = analyze_reduce_language_categorical([stats1, stats2], value_protection=True)
+        assert stats == {
+            "categories": [CATEGORICAL_UNKNOWN_TOKEN, None, "female", "male"],
+            "no_of_rare_categories": 2,
+        }
+
+
+class TestLanguageCategoricalEncode:
+    def test_2_frequent_and_1_rare_and_1_null_values(self):
+        values = pd.Series(np.repeat(["secret", "male", "female", pd.NA], 100), name="gender")
+        stats = {
+            "categories": [CATEGORICAL_UNKNOWN_TOKEN, None, "female", "male"],
+            "no_of_rare_categories": 1,
+        }
+        expected = pd.Series(
+            np.repeat([CATEGORICAL_UNKNOWN_TOKEN, "male", "female", pd.NA], 100), name="gender", dtype="string"
+        )
+        encoded = encode_language_categorical(values, stats)
+        pd.testing.assert_series_equal(encoded, expected)
+
+
+class TestLanguageCategoricalDecode:
+    @pytest.fixture
+    def col_stats(self):
+        return {"categories": [CATEGORICAL_UNKNOWN_TOKEN, None, "apple", "banana", "cherry"]}
+
+    @pytest.fixture
+    def sample_values(self):
+        return pd.Series(["apple", "durian", "banana", "elderberry", "cherry", "fig", None])
+
+    def test_language_categorical_decode(self, sample_values, col_stats):
+        decoded = decode_language_categorical(sample_values, col_stats)
+        expected = pd.Series(["apple", None, "banana", None, "cherry", None, None], dtype=decoded.dtype)
+        pd.testing.assert_series_equal(decoded, expected)
diff --git a/tests/unit/encoding_types/language/test_datetime.py b/tests/unit/encoding_types/language/test_datetime.py
new file mode 100644
index 0000000..15eab3e
--- /dev/null
+++ b/tests/unit/encoding_types/language/test_datetime.py
@@ -0,0 +1,155 @@
+# Copyright 2025 MOSTLY AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pandas as pd
+import pytest
+
+from mostlyai.engine._encoding_types.language.datetime import (
+    analyze_language_datetime,
+    analyze_reduce_language_datetime,
+    decode_language_datetime,
+    encode_language_datetime,
+)
+from mostlyai.engine.domain import ModelEncodingType
+
+
+class TestLanguageDatetimeAnalyze:
+    def test_analyze_language_datetime(self):
+        birth_dates = pd.Series(
+            [
+                "1910-01-01",
+                "",
+                "1930-01-31",
+                "1940-02-12",
+                "",
+                "1971-09-01",
+                "1983-05-19",
+                "1998-05-24",
+            ]
+            * 11,
+            name="birth_date",
+        )
+        keys = pd.Series(range(len(birth_dates)), name="id")
+        stats = analyze_language_datetime(birth_dates, keys)
+        assert stats["has_nan"] is True
+        assert stats["min11"] == ["1910-01-01"] * 11
+        assert stats["max11"] == ["1998-05-24"] * 11
+
+
+class TestLanguageDatetimeAnalyzeReduce:
+    def test_analyze_reduce_language_datetime(self):
+        stats1 = {
+            "has_nan": True,
+            "min11": ["1910-01-01"] * 11,
+            "max11": ["1998-05-24"] * 11,
+        }
+        stats2 = {
+            "has_nan": False,
+            "min11": ["2000-01-01"] * 11,
+            "max11": ["2024-12-31"] * 11,
+        }
+        reduced = analyze_reduce_language_datetime([stats1, stats2])
+        assert reduced["has_nan"] is True
+        assert reduced["min5"] == ["1910-01-01"] * 5
+        assert reduced["max5"] == ["2024-12-31"] * 5
+
+
+class TestLanguageDatetimeEncode:
+    def test_encode_language_datetime(self):
+        values = pd.Series(
+            [
+                "1910-01-01",
+                "",
+                "1930-01-31",
+                "1940-02-12",
+                "",
+                "1971-09-01",
+                "1983-05-19",
+                "1998-05-24",
+            ],
+            name="birth_date",
+        )
+        stats = {
+            "has_nan": True,
+            "min5": ["1930-01-31"] * 5,
+            "max5": ["2024-12-31"] * 5,
+        }
+        encoded = encode_language_datetime(values, stats)
+        assert encoded.dtype == "datetime64[us]"
+        assert encoded.isna().sum() == 2
+        assert encoded.iloc[0] == pd.Timestamp("1930-01-31")
+        assert encoded.iloc[1] is pd.NaT
+        assert encoded.iloc[2] == pd.Timestamp("1930-01-31")
+        assert encoded.iloc[3] == pd.Timestamp("1940-02-12")
+        assert encoded.iloc[4] is pd.NaT
+        assert encoded.iloc[5] == pd.Timestamp("1971-09-01")
+        assert encoded.iloc[6] == pd.Timestamp("1983-05-19")
+
+
+class TestLanguageDatetimeDecode:
+    @pytest.fixture
+    def datetime_stats(self):
+        return {
+            "encoding_type": ModelEncodingType.language_datetime,
+            "has_nan": True,
+            "min5": ["2000-01-01"] * 5,
+            "max5": ["2024-12-31"] * 5,
+        }
+
+    @pytest.fixture
+    def no_clip_stats(self):
+        return {
+            "encoding_type": ModelEncodingType.language_datetime,
+            "has_nan": True,
+            "min5": ["1900-01-01"] * 5,
+            "max5": ["2100-01-01"] * 5,
+        }
+
+    @pytest.fixture
+    def sample_dates(self):
+        return pd.Series(
+            [
+                "2021-05-20 14:30:00",  # valid datetime with time
+                "2020-02-30",  # Feb 30 is invalid; should be clamped to Feb 29, 2020
+                "1999-12-31",  # below the min bound -> will be clipped upward
+                "2025-01-01",  # above the max bound -> will be clipped downward
+                "abcd",  # invalid date string -> becomes NaT
+                "",  # empty string -> becomes NaT
+                "_INVALID_",  # marked as invalid -> becomes NaT
+                "2010-10-10",  # valid date without explicit time (defaults to 00:00:00)
+            ]
+        )
+
+    def test_datetime_dtype_bounds_and_invalids(self, sample_dates, datetime_stats):
+        decoded = decode_language_datetime(sample_dates, datetime_stats)
+        assert decoded.dtype == "datetime64[ns]"
+        non_null = decoded.dropna()
+        min_bound = pd.to_datetime(datetime_stats["min5"][0])
+        max_bound = pd.to_datetime(datetime_stats["max5"][0])
+        for dt in non_null:
+            assert dt >= min_bound
+            assert dt <= max_bound
+        assert all(pd.isna(decoded.iloc[4:7]))
+
+    def test_date_day_clamping(self, no_clip_stats):
+        s = pd.Series(["2021-04-31"])
+        decoded = decode_language_datetime(s, no_clip_stats)
+        expected = pd.Timestamp("2021-04-30 00:00:00")
+        assert decoded.iloc[0] == expected
+
+    def test_time_extraction(self, no_clip_stats):
+        s = pd.Series(["2021-07-15T23:59:59.123"])
+        decoded = decode_language_datetime(s, no_clip_stats)
+        expected = pd.Timestamp("2021-07-15 23:59:59.123")
+        assert decoded.iloc[0] == expected
diff --git a/tests/unit/encoding_types/language/test_numeric.py b/tests/unit/encoding_types/language/test_numeric.py
new file mode 100644
index 0000000..1331468
--- /dev/null
+++ b/tests/unit/encoding_types/language/test_numeric.py
@@ -0,0 +1,124 @@
+# Copyright 2025 MOSTLY AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from mostlyai.engine._encoding_types.language.numeric import (
+    analyze_language_numeric,
+    analyze_reduce_language_numeric,
+    decode_language_numeric,
+    encode_language_numeric,
+)
+from mostlyai.engine.domain import ModelEncodingType
+
+
+class TestLanguageNumericAnalyze:
+    def test_analyze_language_numeric(self):
+        values = pd.Series([0, 1, 2, 3, 4, 5] * 11, name="value")
+        ids = pd.Series(range(len(values)), name="id")
+        stats = analyze_language_numeric(values, ids)
+        assert stats["has_nan"] is False
+        assert stats["max11"] == [5] * 11
+        assert stats["min11"] == [0] * 11
+
+
+class TestLanguageNumericAnalyzeReduce:
+    def test_analyze_reduce_language_numeric(self):
+        stats1 = {
+            "has_nan": False,
+            "max11": [5] * 11,
+            "min11": [0] * 11,
+            "max_scale": 0,
+        }
+        stats2 = {
+            "has_nan": True,
+            "max11": [10] * 11,
+            "min11": [6] * 11,
+            "max_scale": 1,
+        }
+        reduced = analyze_reduce_language_numeric([stats1, stats2])
+        assert reduced["has_nan"] is True
+        assert reduced["max5"] == [10] * 5
+        assert reduced["min5"] == [0] * 5
+        assert reduced["max_scale"] == 1
+
+
+class TestLanguageNumericEncode:
+    def test_encode_language_numeric(self):
+        values = pd.Series([-1, 0, 1, 2, 3, 4, 5, 6], name="value")
+        stats = {
+            "has_nan": False,
+            "max5": [5] * 5,
+            "min5": [0] * 5,
+            "max_scale": 0,
+        }
+        encoded = encode_language_numeric(values, stats)
+        assert encoded.dtype == "Int64"
+        assert encoded.isna().sum() == 0
+        assert encoded.iloc[0] == 0
+        assert encoded.iloc[1] == 0
+        assert encoded.iloc[2] == 1
+        assert encoded.iloc[3] == 2
+        assert encoded.iloc[4] == 3
+        assert encoded.iloc[5] == 4
+        assert encoded.iloc[6] == 5
+        assert encoded.iloc[7] == 5
+
+
+class TestLanguageNumericDecode:
+    @pytest.fixture
+    def int_stats(self):
+        return {
+            "encoding_type": ModelEncodingType.language_numeric,
+            "has_nan": False,
+            "max5": [91] * 5,
+            "max_scale": 0,
+            "min5": [17] * 5,
+        }
+
+    @pytest.fixture
+    def float_stats(self):
+        return {
+            "encoding_type": ModelEncodingType.language_numeric,
+            "has_nan": False,
+            "max5": [91.12] * 5,
+            "max_scale": 2,
+            "min5": [17.0] * 5,
+        }
+
+    @pytest.fixture
+    def sample_values(self):
+        return pd.Series(["25.3541", "99.99", "-312.0", "61", None, "35.10091", "-1.223"])
+
+    @pytest.mark.parametrize(
+        "stats_name, expected_dtype",
+        [
+            ("int_stats", "Int64"),
+            ("float_stats", float),
+        ],
+    )
+    def test_decode_language_numeric(self, sample_values, request, stats_name, expected_dtype):
+        stats = request.getfixturevalue(stats_name)
+        decoded = decode_language_numeric(sample_values, stats)
+        assert decoded.dtype == expected_dtype
+        non_null = decoded.dropna()  # we don't enforce compatibility with "has_nan"
+        max_val = stats["max5"][0]
+        min_val = stats["min5"][0]
+        round_digits = stats["max_scale"]
+        for v in non_null:
+            assert np.isclose(v, round(v, round_digits), atol=1e-8)
+        assert all(non_null <= max_val)
+        assert all(non_null >= min_val)
diff --git a/tests/unit/encoding_types/tabular/__init__.py b/tests/unit/encoding_types/tabular/__init__.py
new file mode 100644
index 0000000..a18e33e
--- /dev/null
+++ b/tests/unit/encoding_types/tabular/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 MOSTLY AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unit/encoding_types/test_categorical.py b/tests/unit/encoding_types/tabular/test_categorical.py
similarity index 100%
rename from tests/unit/encoding_types/test_categorical.py
rename to tests/unit/encoding_types/tabular/test_categorical.py
diff --git a/tests/unit/encoding_types/test_character.py b/tests/unit/encoding_types/tabular/test_character.py
similarity index 100%
rename from tests/unit/encoding_types/test_character.py
rename to tests/unit/encoding_types/tabular/test_character.py
diff --git a/tests/unit/encoding_types/test_datetime.py b/tests/unit/encoding_types/tabular/test_datetime.py
similarity index 100%
rename from tests/unit/encoding_types/test_datetime.py
rename to tests/unit/encoding_types/tabular/test_datetime.py
diff --git a/tests/unit/encoding_types/test_itt.py b/tests/unit/encoding_types/tabular/test_itt.py
similarity index 100%
rename from tests/unit/encoding_types/test_itt.py
rename to tests/unit/encoding_types/tabular/test_itt.py
diff --git a/tests/unit/encoding_types/test_lat_long.py b/tests/unit/encoding_types/tabular/test_lat_long.py
similarity index 100%
rename from tests/unit/encoding_types/test_lat_long.py
rename to tests/unit/encoding_types/tabular/test_lat_long.py
diff --git a/tests/unit/encoding_types/test_numeric.py b/tests/unit/encoding_types/tabular/test_numeric.py
similarity index 100%
rename from tests/unit/encoding_types/test_numeric.py
rename to tests/unit/encoding_types/tabular/test_numeric.py
diff --git a/tests/unit/test_encoding.py b/tests/unit/test_encoding.py
index 62356af..0d7db9b 100644
--- a/tests/unit/test_encoding.py
+++ b/tests/unit/test_encoding.py
@@ -162,21 +162,21 @@ def test_long_sequential_values(self):
 
 class TestLanguageEncode:
     @pytest.fixture(scope="class")
-    def ctx_encoding_types(self):
+    def ctx_stats(self):
         return {
-            "table0::col_obj": ModelEncodingType.tabular_categorical,
-            "table1::col_int": ModelEncodingType.tabular_numeric_auto,
-            "table1::col_float": ModelEncodingType.tabular_numeric_auto,
-            "table1::col_bool": ModelEncodingType.tabular_categorical,
-            "table2::col_date": ModelEncodingType.tabular_datetime,
-            "table3::col_datetime": ModelEncodingType.tabular_datetime,
+            "columns": {
+                "table0::col_obj": {},
+                "table1::col_int": {},
+                "table1::col_float": {},
+                "table1::col_bool": {},
+                "table2::col_date": {},
+                "table3::col_datetime": {},
+            }
         }
 
     @pytest.fixture(scope="class")
-    def tgt_encoding_types(self):
-        return {
-            "table3::col_str": ModelEncodingType.language_text,
-        }
+    def tgt_stats(self):
+        return {"columns": {"table3::col_str": {}}}
 
     @pytest.fixture(scope="class")
     def ctx_df(self):
@@ -208,9 +208,9 @@ def tgt_df(self):
         )
         return df
 
-    def test_format_df(self, ctx_df, tgt_df, ctx_encoding_types, tgt_encoding_types):
-        formatted_ctx_df = format_df(ctx_df, is_target=False, columns=list(ctx_encoding_types.keys()))
-        formatted_tgt_df = format_df(tgt_df, is_target=True, columns=list(tgt_encoding_types.keys()))
+    def test_format_df(self, ctx_df, tgt_df, ctx_stats, tgt_stats):
+        formatted_ctx_df = format_df(ctx_df, is_target=False, stats=ctx_stats)
+        formatted_tgt_df = format_df(tgt_df, is_target=True, stats=tgt_stats)
         ctx = formatted_ctx_df.iloc[0]
         tgt = formatted_tgt_df.iloc[0]