diff --git a/.github/workflows/run-tests-cpu.yaml b/.github/workflows/run-tests-cpu.yaml index bf66900..6a36d09 100644 --- a/.github/workflows/run-tests-cpu.yaml +++ b/.github/workflows/run-tests-cpu.yaml @@ -32,16 +32,16 @@ jobs: run: uv sync --frozen --extra cpu - name: Run | Tests -> unit - run: uv run pytest tests/unit + run: uv run --no-sync pytest tests/unit - name: Build mkdocs - run: uv run mkdocs build --strict + run: uv run --no-sync mkdocs build --strict - name: Run tests -> end_to_end -> sequential - run: uv run pytest tests/end_to_end/test_tabular_sequential.py + run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential.py - name: Run tests -> end_to_end -> sequential context - run: uv run pytest tests/end_to_end/test_tabular_sequential_context.py + run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential_context.py run-tests-cpu-end-to-end-nonsequential: runs-on: ubuntu-latest @@ -66,4 +66,4 @@ jobs: run: uv sync --frozen --extra cpu --no-group docs - name: Run tests -> end_to_end all except sequential - run: uv run pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/ + run: uv run --no-sync pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/ diff --git a/.github/workflows/run-tests-gpu.yaml b/.github/workflows/run-tests-gpu.yaml index 3958cbc..269f8a9 100644 --- a/.github/workflows/run-tests-gpu.yaml +++ b/.github/workflows/run-tests-gpu.yaml @@ -42,10 +42,10 @@ jobs: run: nvidia-smi - name: Run tests -> end_to_end -> sequential - run: uv run pytest tests/end_to_end/test_tabular_sequential.py + run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential.py - name: Run tests -> end_to_end -> sequential context - run: uv run pytest tests/end_to_end/test_tabular_sequential_context.py + run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential_context.py - name: Run tests -> end_to_end all except sequential - run: uv run pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/ + run: uv run --no-sync pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/ diff --git a/Makefile b/Makefile index 853ab26..9fc72a1 100644 --- a/Makefile +++ b/Makefile @@ -12,11 +12,11 @@ install: # Install dependencies .PHONY: lint lint: ## Run lints - uv run pre-commit run --all-files + uv run --no-sync pre-commit run --all-files .PHONY: test test: ## Run tests - uv run pytest + uv run --no-sync pytest .PHONY: all all: clean install lint test ## Run the commands: clean install lint test diff --git a/mostlyai/engine/_encoding_types/language/categorical.py b/mostlyai/engine/_encoding_types/language/categorical.py new file mode 100644 index 0000000..aebfc39 --- /dev/null +++ b/mostlyai/engine/_encoding_types/language/categorical.py @@ -0,0 +1,75 @@ +# Copyright 2025 MOSTLY AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Categorical encoding for language models.
+"""
+
+import numpy as np
+import pandas as pd
+
+from mostlyai.engine._common import safe_convert_string, STRING
+
+CATEGORICAL_UNKNOWN_TOKEN = "_RARE_"
+
+
+def analyze_language_categorical(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict:
+    values = safe_convert_string(values)
+    # count distinct root_keys per categorical value for rare-category protection
+    df = pd.concat([root_keys, values], axis=1)
+    cnt_values = df.groupby(values.name)[root_keys.name].nunique().to_dict()
+    stats = {"has_nan": bool(values.isna().any()), "cnt_values": cnt_values}
+    return stats
+
+
+def analyze_reduce_language_categorical(stats_list: list[dict], value_protection: bool = True) -> dict:
+    # sum up all counts for each categorical value
+    cnt_values: dict[str, int] = {}
+    for item in stats_list:
+        for value, count in item["cnt_values"].items():
+            cnt_values[value] = cnt_values.get(value, 0) + count
+    # create alphabetically sorted list of non-rare categories
+    known_categories = sorted(cnt_values.keys())
+    if value_protection:
+        # stochastic threshold for rare categories
+        rare_min = 5 + int(3 * np.random.uniform())
+    else:
+        rare_min = 0
+    categories = [k for k in known_categories if cnt_values[k] >= rare_min]
+    no_of_rare_categories = len(known_categories) - len(categories)
+    # add None to categories, if any are present
+    if any([j["has_nan"] for j in stats_list]):
+        categories = [None] + categories
+    # add special token for UNKNOWN categories at first position
+    if no_of_rare_categories > 0:
+        categories = [CATEGORICAL_UNKNOWN_TOKEN] + categories
+    stats = {"no_of_rare_categories": no_of_rare_categories, "categories": categories}
+    return stats
+
+
+def encode_language_categorical(values: pd.Series, stats: dict) -> pd.Series:
+    values = safe_convert_string(values)
+    values = values.copy()
+    known_categories = stats["categories"]
+    mask = ~values.isin(known_categories)
+    if None in known_categories:
+        mask &= ~pd.isna(values)
+    values[mask] = CATEGORICAL_UNKNOWN_TOKEN
+    return values
+
+
+def decode_language_categorical(x: pd.Series, col_stats: dict[str, str]) -> pd.Series:
+    x = x.astype(STRING)
+    allowed_categories = col_stats.get("categories", [])
+    return x.where(x.isin(allowed_categories), other=None)
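Editor's note: the new language encoders mirror the analyze / analyze_reduce / encode / decode contract of the tabular encoding types. A minimal sketch of the categorical round trip above, with illustrative data and column names (not part of the diff):

    import pandas as pd
    from mostlyai.engine._encoding_types.language.categorical import (
        analyze_language_categorical,
        analyze_reduce_language_categorical,
        encode_language_categorical,
    )

    # 41 rows, one distinct root key per row; "rare" occurs for a single root key
    values = pd.Series(["a"] * 20 + ["b"] * 20 + ["rare"], name="col")
    root_keys = pd.Series(range(41), name="rk")
    partial = analyze_language_categorical(values, root_keys)
    stats = analyze_reduce_language_categorical([partial])
    # "rare" falls below the stochastic threshold (5..7 distinct root keys),
    # so encoding maps it to the _RARE_ token
    encoded = encode_language_categorical(values, stats)
    assert encoded.iloc[-1] == "_RARE_"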
diff --git a/mostlyai/engine/_encoding_types/language/datetime.py b/mostlyai/engine/_encoding_types/language/datetime.py
new file mode 100644
index 0000000..c7a1b37
--- /dev/null
+++ b/mostlyai/engine/_encoding_types/language/datetime.py
@@ -0,0 +1,143 @@
+# Copyright 2025 MOSTLY AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import calendar
+
+import numpy as np
+import pandas as pd
+
+from mostlyai.engine._common import safe_convert_datetime
+
+
+def analyze_language_datetime(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict:
+    values = safe_convert_datetime(values)
+    df = pd.concat([root_keys, values], axis=1)
+    # determine lowest/highest values by root ID, and return top 11
+    min_dates = df.groupby(root_keys.name)[values.name].min().dropna()
+    min11 = min_dates.sort_values(ascending=True).head(11).astype(str).tolist()
+    max_dates = df.groupby(root_keys.name)[values.name].max().dropna()
+    max11 = max_dates.sort_values(ascending=False).head(11).astype(str).tolist()
+    # determine if there are any NaN values
+    has_nan = bool(values.isna().any())
+    # return stats
+    stats = {
+        "has_nan": has_nan,
+        "min11": min11,
+        "max11": max11,
+    }
+    return stats
+
+
+def analyze_reduce_language_datetime(stats_list: list[dict], value_protection: bool = True) -> dict:
+    # check if there are missing values
+    has_nan = any([j["has_nan"] for j in stats_list])
+    # determine min / max 5 values to map too low / too high values to
+    min11 = sorted([v for min11 in [j["min11"] for j in stats_list] for v in min11], reverse=False)[:11]
+    max11 = sorted([v for max11 in [j["max11"] for j in stats_list] for v in max11], reverse=True)[:11]
+    if value_protection:
+        # extreme value protection - discard lowest/highest 5 values
+        if len(min11) < 11 or len(max11) < 11:
+            # less than 11 subjects with non-NULL values; we need to protect all
+            min5 = []
+            max5 = []
+        else:
+            min5 = [str(v) for v in min11[5:10]]  # drop the 5 lowest; keep the 6th to 10th lowest
+            max5 = [str(v) for v in max11[5:10]]  # drop the 5 highest; keep the 6th to 10th highest
+    else:
+        min5 = min11[0:5]
+        max5 = max11[0:5]
+    stats = {
+        "has_nan": has_nan,
+        "min5": min5,
+        "max5": max5,
+    }
+    return stats
+
+
+def encode_language_datetime(values: pd.Series, stats: dict, _: pd.Series | None = None) -> pd.Series:
+    # convert
+    values = safe_convert_datetime(values)
+    values = values.copy()
+    # reset index, as `values.mask` can throw errors for misaligned indices
+    values.reset_index(drop=True, inplace=True)
+    # replace extreme values with randomly sampled 5-th to 10-th largest/smallest values
+    min5 = stats["min5"] if len(stats["min5"]) > 0 else [0]
+    max5 = stats["max5"] if len(stats["max5"]) > 0 else [0]
+    min5 = pd.Series(min5, dtype=values.dtype)
+    max5 = pd.Series(max5, dtype=values.dtype)
+    values.mask(
+        values < min5[0],
+        min5.sample(n=len(values), replace=True, ignore_index=True),
+        inplace=True,
+    )
+    values.mask(
+        values > max5[0],
+        max5.sample(n=len(values), replace=True, ignore_index=True),
+        inplace=True,
+    )
+    return values
+
+
+def _clip_datetime(x: pd.Series, min5: list, max5: list) -> pd.Series:
+    x_dt = pd.to_datetime(x, errors="coerce")
+    min_arr = pd.to_datetime(min5).to_numpy(dtype="datetime64[ns]")
+    max_arr = pd.to_datetime(max5).to_numpy(dtype="datetime64[ns]")
+    n = len(x_dt)
+    random_mins = np.random.choice(min_arr, size=n)
+    random_maxs = np.random.choice(max_arr, size=n)
+    clipped = np.minimum(np.maximum(x_dt.to_numpy(dtype="datetime64[ns]"), random_mins), random_maxs)
+    return pd.Series(clipped, index=x.index)
+
+
+def decode_language_datetime(x: pd.Series, col_stats: dict[str, str]) -> pd.Series:
+    x = x.where(~x.isin(["", "_INVALID_"]), np.nan)
+
+    valid_mask = (
+        x.str.len().ge(10)
+        & x.str.slice(0, 4).str.isdigit()
+        & x.str.slice(5, 7).str.isdigit()
+        & x.str.slice(8, 10).str.isdigit()
+    )
+    if valid_mask.sum() > 0:  # expected "YYYY-MM-DD" prefix
+        # handle the date portion, ensuring validity
+        years = x[valid_mask].str.slice(0, 4).astype(int)
+        months = x[valid_mask].str.slice(5, 7).astype(int)
+        days = x[valid_mask].str.slice(8, 10).astype(int)
+
+        # clamp days according to maximum possible day of the month of a given year
+        last_days = np.array([calendar.monthrange(y, m)[1] for y, m in zip(years, months)])
+        clamped_days = np.minimum(days, last_days)
+
+        # rebuild the date portion
+        new_date = (
+            years.astype(str).str.zfill(4)
+            + "-"
+            + months.astype(str).str.zfill(2)
+            + "-"
+            + pd.Series(clamped_days, index=years.index).astype(str).str.zfill(2)
+        )
+
+        # handle the time portion, ensuring validity
+        remainder = x[valid_mask].str.slice(10)
+
+        time_regex = r"^[ T]?(\d{2}:\d{2}:\d{2}(?:\.\d+)?)"
+        valid_time = remainder.str.extract(time_regex, expand=False)
+        valid_time = valid_time.fillna("00:00:00")
+        valid_time = " " + valid_time
+
+        new_date = new_date + valid_time
+        x.loc[valid_mask] = new_date
+
+    x = pd.to_datetime(x, errors="coerce")
+    x = _clip_datetime(x, col_stats["min5"], col_stats["max5"])
+    return x.astype("datetime64[ns]")
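Editor's note: a sketch of how decode_language_datetime repairs raw model output before clipping — illustrative stats dict, not part of the diff:

    import pandas as pd
    from mostlyai.engine._encoding_types.language.datetime import decode_language_datetime

    col_stats = {
        "min5": ["2000-01-01 00:00:00"] * 5,
        "max5": ["2020-12-31 23:59:59"] * 5,
    }
    raw = pd.Series(["2021-02-29T12:00:00", "1999-01-01", "_INVALID_"])
    decoded = decode_language_datetime(raw, col_stats)
    # day 29 is clamped to 2021-02-28 (non-leap year), then clipped down to the
    # max5 bound; "1999-01-01" is clipped up to the min5 bound; "_INVALID_" -> NaT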
diff --git a/mostlyai/engine/_encoding_types/language/numeric.py b/mostlyai/engine/_encoding_types/language/numeric.py
new file mode 100644
index 0000000..a3723b4
--- /dev/null
+++ b/mostlyai/engine/_encoding_types/language/numeric.py
@@ -0,0 +1,136 @@
+# Copyright 2025 MOSTLY AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pandas as pd
+
+from mostlyai.engine._common import safe_convert_numeric
+from mostlyai.engine._encoding_types.tabular.numeric import _type_safe_numeric_series
+from mostlyai.engine.domain import ModelEncodingType
+
+
+def analyze_language_numeric(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict:
+    values = safe_convert_numeric(values)
+
+    # determine lowest/highest values by root ID, and return top 11
+    df = pd.concat([root_keys, values], axis=1)
+    min_values = df.groupby(root_keys.name)[values.name].min().dropna()
+    min11 = min_values.sort_values(ascending=True).head(11).tolist()
+    max_values = df.groupby(root_keys.name)[values.name].max().dropna()
+    max11 = max_values.sort_values(ascending=False).head(11).tolist()
+
+    # determine if there are any NaN values
+    has_nan = bool(values.isna().any())
+
+    # determine max scale
+    def count_scale(num: float) -> int:
+        # represent number as fixed point string, remove trailing zeros and decimal point
+        num = format(num, "f").rstrip("0").rstrip(".")
+        if "." in num:
+            # in case of decimal, return number of digits after decimal point
+            return len(num.split(".")[1])
+        # in case of integer, return 0
+        return 0
+
+    max_scale = int(values.apply(count_scale).max())
+
+    stats = {
+        "has_nan": has_nan,
+        "max_scale": max_scale,
+        "min11": min11,
+        "max11": max11,
+    }
+    return stats
+
+
+def analyze_reduce_language_numeric(stats_list: list[dict], value_protection: bool = True) -> dict:
+    # check for occurrence of NaN values
+    has_nan = any([j["has_nan"] for j in stats_list])
+
+    # determine max scale
+    max_scale = max([j["max_scale"] for j in stats_list])
+
+    # determine min / max 5 values to map too low / too high values to
+    min11 = sorted([v for min11 in [j["min11"] for j in stats_list] for v in min11], reverse=False)[:11]
+    max11 = sorted([v for max11 in [j["max11"] for j in stats_list] for v in max11], reverse=True)[:11]
+    if value_protection:
+        # extreme value protection - discard lowest/highest 5 values
+        if len(min11) < 11 or len(max11) < 11:
+            # less than 11 subjects with non-NULL values; we need to protect all
+            min5 = []
+            max5 = []
+        else:
+            min5 = min11[5:10]  # drop the 5 lowest; keep the 6th to 10th lowest
+            max5 = max11[5:10]  # drop the 5 highest; keep the 6th to 10th highest
+    else:
+        min5 = min11[0:5]
+        max5 = max11[0:5]
+
+    stats = {
+        "encoding_type": ModelEncodingType.language_numeric.value,
+        "has_nan": has_nan,
+        "max_scale": max_scale,
+        "min5": min5,
+        "max5": max5,
+    }
+
+    return stats
+
+
+def encode_language_numeric(values: pd.Series, stats: dict, _: pd.Series | None = None) -> pd.Series:
+    values = safe_convert_numeric(values)
+    # try to convert to int, if possible
+    dtype = "Int64" if stats["max_scale"] == 0 else "Float64"
+    if dtype == "Int64":
+        values = values.round()
+    try:
+        values = values.astype(dtype)
+    except TypeError:
+        if dtype == "Int64":  # if couldn't safely convert to int, stick to float
+            dtype = "Float64"
+            values = values.astype(dtype)
+    # reset index, as `values.mask` can throw errors for misaligned indices
+    values.reset_index(drop=True, inplace=True)
+    # replace extreme values with randomly sampled 5-th to 10-th largest/smallest values
+    min5 = _type_safe_numeric_series(stats["min5"] or [0], dtype)
+    max5 = _type_safe_numeric_series(stats["max5"] or [0], dtype)
+    values.mask(
+        values < min5[0],
+        min5.sample(n=len(values), replace=True, ignore_index=True),
+        inplace=True,
+    )
+    values.mask(
+        values > max5[0],
+        max5.sample(n=len(values), replace=True, ignore_index=True),
+        inplace=True,
+    )
+    return values
+
+
+def _clip_numeric(x: pd.Series, min5: list, max5: list) -> pd.Series:
+    x_numeric = pd.to_numeric(x, errors="coerce")
+    min_arr = np.array(min5, dtype=x_numeric.dtype)
+    max_arr = np.array(max5, dtype=x_numeric.dtype)
+    n = len(x_numeric)
+    random_mins = np.random.choice(min_arr, size=n)
+    random_maxs = np.random.choice(max_arr, size=n)
+    clipped = np.minimum(np.maximum(x_numeric.to_numpy(), random_mins), random_maxs)
+    return pd.Series(clipped, index=x.index)
+
+
+def decode_language_numeric(x: pd.Series, col_stats: dict[str, str]) -> pd.Series:
+    x = pd.to_numeric(x, errors="coerce")
+    x = x.round(col_stats["max_scale"])
+    x = _clip_numeric(x, col_stats["min5"], col_stats["max5"])
+    dtype = "Int64" if col_stats["max_scale"] == 0 else float
+    return x.astype(dtype)
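Editor's note: encode_language_numeric applies the same extreme-value replacement scheme as the tabular path. A sketch with illustrative stats (not part of the diff); min5 is ascending, max5 descending, as produced by analyze_reduce_language_numeric:

    import pandas as pd
    from mostlyai.engine._encoding_types.language.numeric import encode_language_numeric

    stats = {"max_scale": 0, "min5": [10, 11, 12, 13, 14], "max5": [99, 98, 97, 96, 95]}
    values = pd.Series([3, 50, 120])
    encoded = encode_language_numeric(values, stats)
    # 3 < min5[0] and 120 > max5[0], so both are replaced by random draws from
    # min5 / max5 respectively; 50 passes through unchanged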
diff --git a/mostlyai/engine/_encoding_types/language/text.py b/mostlyai/engine/_encoding_types/language/text.py
index 2e89b75..245699a 100644
--- a/mostlyai/engine/_encoding_types/language/text.py
+++ b/mostlyai/engine/_encoding_types/language/text.py
@@ -14,7 +14,7 @@
 
 import pandas as pd
 
-from mostlyai.engine._common import safe_convert_string
+from mostlyai.engine._common import safe_convert_string, STRING
 
 
 def analyze_text(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict:
@@ -39,3 +39,7 @@ def analyze_reduce_text(stats_list: list[dict], _: bool = True) -> dict:
         "nchar_max": nchar_max,
     }
     return stats
+
+
+def decode_text(x: pd.Series, col_stats: dict[str, str]) -> pd.Series:
+    return x.astype(STRING)
diff --git a/mostlyai/engine/_encoding_types/tabular/numeric.py b/mostlyai/engine/_encoding_types/tabular/numeric.py
index 44cf58c..8edff93 100644
--- a/mostlyai/engine/_encoding_types/tabular/numeric.py
+++ b/mostlyai/engine/_encoding_types/tabular/numeric.py
@@ -165,7 +165,7 @@ def analyze_numeric(
         # do not count values, if there are too many
         cnt_values = None
 
-    # determine lowest/highest values by root ID, and return Top 10
+    # determine lowest/highest values by root ID, and return top 11
     df = pd.concat([root_keys, values], axis=1)
     min_values = df.groupby(root_keys.name)[values.name].min().dropna()
     min11 = min_values.sort_values(ascending=True).head(11).astype("float").tolist()
diff --git a/mostlyai/engine/_language/encoding.py b/mostlyai/engine/_language/encoding.py
index cb74ada..752624f 100644
--- a/mostlyai/engine/_language/encoding.py
+++ b/mostlyai/engine/_language/encoding.py
@@ -24,17 +24,36 @@
 from mostlyai.engine._common import is_sequential, ProgressCallback, ProgressCallbackWrapper, TABLE_COLUMN_INFIX
 from mostlyai.engine._workspace import ensure_workspace_dir, Workspace, reset_dir
+from mostlyai.engine._encoding_types.language.categorical import encode_language_categorical
+from mostlyai.engine._encoding_types.language.numeric import encode_language_numeric
+from mostlyai.engine._encoding_types.language.datetime import encode_language_datetime
 
 _LOG = logging.getLogger(__name__)
 
 
-def format_df(df: pd.DataFrame, columns: list[str], is_target: bool = False) -> pd.DataFrame:
-    df = df[columns].copy()
+def apply_encoding_types(df: pd.DataFrame, stats: dict) -> pd.DataFrame:
+    for col, col_stats in stats["columns"].items():
+        if col_stats["encoding_type"] == "LANGUAGE_CATEGORICAL":
+            df[col] = encode_language_categorical(df[col], col_stats)
+        elif col_stats["encoding_type"] == "LANGUAGE_NUMERIC":
+            df[col] = encode_language_numeric(df[col], col_stats)
+        elif col_stats["encoding_type"] == "LANGUAGE_DATETIME":
+            df[col] = encode_language_datetime(df[col], col_stats)
+    return df
+
+
+def drop_sequential_columns(df: pd.DataFrame) -> pd.DataFrame:
     # Some columns (e.g., SCP columns) may contain np.ndarray, which are not JSON serializable
     # We need to drop them before converting the DataFrame to JSON
     sequential_columns = [col for col in df.columns if is_sequential(df[col])]
     df = df.drop(columns=sequential_columns)
-    _LOG.info(f"Formatting {'target' if is_target else 'context'} columns {df.columns.tolist()} to JSON")
+    return df
+
+
+def format_df(df: pd.DataFrame, stats: dict, is_target: bool = False) -> pd.DataFrame:
+    columns = list(stats["columns"].keys())
+    df = df[columns].copy()
+    _LOG.info(f"Formatting {'target' if is_target else 'context'} columns {columns} to JSON")
     # convert date format to ISO so that it's JSON serializable
     for col in df.columns:
         if is_datetime64_any_dtype(df[col]):
@@ -76,15 +95,21 @@ def row_to_json(row: pd.Series, is_target: bool = False) -> str:
 
 def encode_df(
     ctx_df: pd.DataFrame,
-    ctx_columns: list[str],
+    ctx_stats: dict | 
None = None, tgt_df: pd.DataFrame | None = None, - tgt_columns: list[str] | None = None, + tgt_stats: dict | None = None, ) -> pd.DataFrame: - assert (tgt_df is None) == (tgt_columns is None), "tgt_df and tgt_columns must be both None or both not None" + assert (tgt_df is None) == (tgt_stats is None), "tgt_df and tgt_stats must be both None or both not None" + if ctx_stats is None: + ctx_stats = {"columns": {}} df = pd.DataFrame() - df["ctx"] = format_df(ctx_df, columns=ctx_columns, is_target=False) - if tgt_df is not None and tgt_columns is not None: - df["tgt"] = format_df(tgt_df, columns=tgt_columns, is_target=True) + ctx_df = drop_sequential_columns(ctx_df) + ctx_df = apply_encoding_types(ctx_df, stats=ctx_stats) + df["ctx"] = format_df(ctx_df, stats=ctx_stats, is_target=False) + if tgt_df is not None and tgt_stats is not None: + tgt_df = drop_sequential_columns(tgt_df) + tgt_df = apply_encoding_types(tgt_df, stats=tgt_stats) + df["tgt"] = format_df(tgt_df, stats=tgt_stats, is_target=True) # log the bounds of n_tokens in this partition content = df["ctx"] + df["tgt"] if "tgt" in df.columns else df["ctx"] @@ -107,19 +132,16 @@ def _encode_partition( ctx_stats: dict | None = None, ) -> None: tgt_df = pd.read_parquet(tgt_partition_file) - tgt_columns = list(tgt_stats.get("columns", {}).keys()) if ctx_partition_file: ctx_df = pd.read_parquet(ctx_partition_file) - ctx_columns = list(ctx_stats.get("columns", {}).keys()) else: # create on-the-fly context ctx_df = pd.DataFrame(index=range(len(tgt_df))) - ctx_columns = [] df = encode_df( ctx_df=ctx_df, - ctx_columns=ctx_columns, + ctx_stats=ctx_stats, tgt_df=tgt_df, - tgt_columns=tgt_columns, + tgt_stats=tgt_stats, ) # shuffle and persist to disk as parquet files df = df.sample(frac=1) diff --git a/mostlyai/engine/_language/engine/hf_engine.py b/mostlyai/engine/_language/engine/hf_engine.py index 18c3c5f..88f30a2 100644 --- a/mostlyai/engine/_language/engine/hf_engine.py +++ b/mostlyai/engine/_language/engine/hf_engine.py @@ -23,8 +23,8 @@ from transformers import AutoTokenizer from mostlyai.engine._language.common import load_base_model_and_config -from mostlyai.engine._language.tokenizer_utils import tokenize_fn from mostlyai.engine._language.formatron_utils import monkey_patch_formatron +from mostlyai.engine._language.tokenizer_utils import tokenize_fn from mostlyai.engine._language.engine.base import EngineMetrics, LanguageEngine from formatron.formatter import FormatterBuilder @@ -66,11 +66,8 @@ def __init__( self.tokenizer.special_tokens_map ) self._json_enforcing_possible = is_peft_adapter or is_trained_lstm_tokenizer - - # apply all necessary monkey patches to the formatron library - if self._json_enforcing_possible: + if self.supports_json_enforcing(): monkey_patch_formatron() - self._logits_processors = None def get_default_batch_size(self) -> int: diff --git a/mostlyai/engine/_language/engine/vllm_engine.py b/mostlyai/engine/_language/engine/vllm_engine.py index 247479a..afa1268 100644 --- a/mostlyai/engine/_language/engine/vllm_engine.py +++ b/mostlyai/engine/_language/engine/vllm_engine.py @@ -26,6 +26,7 @@ from peft import PeftConfig from transformers import AutoTokenizer, AutoConfig, PreTrainedTokenizerBase +from mostlyai.engine._language.formatron_utils import monkey_patch_formatron from vllm import LLM, SamplingParams from vllm.lora.request import LoRARequest from vllm.config import _get_and_verify_max_len @@ -123,6 +124,7 @@ def __init__( add_eos_token=False, ) self._logits_processors = None + monkey_patch_formatron() def 
get_default_batch_size(self) -> int: return 192 diff --git a/mostlyai/engine/_language/formatron_utils.py b/mostlyai/engine/_language/formatron_utils.py index d6ef79f..ec6df8b 100644 --- a/mostlyai/engine/_language/formatron_utils.py +++ b/mostlyai/engine/_language/formatron_utils.py @@ -12,17 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. + import typing + import pandas as pd from formatron.schemas.pydantic import ClassSchema from json import JSONDecodeError -from pydantic import ValidationError +from pydantic import Field, SkipValidation, ValidationError from formatron.formatter import FormatterBuilder -from typing import Literal +from formatron import schemas from formatron.formats import json +from typing import Literal from pydantic import create_model from transformers import PreTrainedTokenizerBase +from mostlyai.engine._encoding_types.language.categorical import CATEGORICAL_UNKNOWN_TOKEN +from mostlyai.engine.domain import ModelEncodingType, RareCategoryReplacementMethod JSON_NULL = "null" @@ -41,41 +46,56 @@ def transform(x: str | None) -> str: return sample_seed.astype("string[pyarrow]").map(transform) -def monkey_patch_formatron(): - # alter the Grammar of formatron's json schema - FORMATRON_WHITESPACE_MAX_REPETITIONS = 10 - SPACE_NONTERMINAL = f"[ \t\n\r]{{0,{FORMATRON_WHITESPACE_MAX_REPETITIONS}}}" - - json.GRAMMAR_HEADER = rf"""integer ::= #"-?(0|[1-9]\\d*)"; - number ::= #"-?(0|[1-9]\\d*)(\\.\\d+)?([eE][+-]?\\d+)?"; - string ::= #'"([^\\\\"\u0000-\u001f]|\\\\["\\\\bfnrt/]|\\\\u[0-9A-Fa-f]{{4}})*"'; - boolean ::= "true"|"false"; - null ::= "null"; - array ::= array_begin (json_value (comma json_value)*)? array_end; - object ::= object_begin (string colon json_value (comma string colon json_value)*)? object_end; - json_value ::= number|string|boolean|null|array|object; - comma ::= #"{SPACE_NONTERMINAL},{SPACE_NONTERMINAL}"; - colon ::= #"{SPACE_NONTERMINAL}:{SPACE_NONTERMINAL}"; - object_begin ::= #" \\{{{SPACE_NONTERMINAL}"; - object_end ::= #"{SPACE_NONTERMINAL}\\}}"; - array_begin ::= #"\\[{SPACE_NONTERMINAL}"; - array_end ::= #"{SPACE_NONTERMINAL}\\]"; - """ - - def get_formatter_builders( - *, seed_df: pd.DataFrame | None = None, size: int | None = None, unseeded_fields: list[str] + *, + seed_df: pd.DataFrame | None = None, + size: int | None = None, + stats: dict, + rare_category_replacement_method: RareCategoryReplacementMethod, ) -> list[FormatterBuilder]: assert (seed_df is not None) ^ (size is not None), "exactly one of seed_df or size must be provided" formatter_builders = [] if seed_df is None: seed_df = pd.DataFrame(index=range(size)) + unseeded_fields = [c for c in list(stats["columns"].keys()) if c not in seed_df.columns.to_list()] + field_types = { + t: [col for col, col_stats in stats["columns"].items() if col_stats["encoding_type"] == t] + for t in ModelEncodingType + } + categorical_fields = field_types.get(ModelEncodingType.language_categorical, []) + numeric_fields = field_types.get(ModelEncodingType.language_numeric, []) + datetime_fields = field_types.get(ModelEncodingType.language_datetime, []) for _, seed_row in seed_df.iterrows(): formatter_builder = FormatterBuilder() model_dict = {} if not seed_row.empty: - model_dict |= {field_name: (Literal[seed_value], ...) for field_name, seed_value in seed_row.items()} - model_dict |= {field_name: (str, ...) for field_name in unseeded_fields} + model_dict |= {field_name: (Literal[seed_value], ...) 
for field_name, seed_value in seed_row.items()}  # type: ignore[valid-type]
+        for field_name in unseeded_fields:
+            if field_name in categorical_fields:
+                categories = stats["columns"][field_name]["categories"]
+                if rare_category_replacement_method == RareCategoryReplacementMethod.sample and len(categories) > 1:
+                    categories = [c for c in categories if c != CATEGORICAL_UNKNOWN_TOKEN]
+                model_dict[field_name] = (Literal[tuple(categories)], ...)  # type: ignore[valid-type]
+            elif field_name in numeric_fields:
+                max_scale = stats["columns"][field_name]["max_scale"]
+                min_min5 = min(stats["columns"][field_name]["min5"] or [0])  # `or [0]`: min5 may be empty if all values were protected
+                max_max5 = max(stats["columns"][field_name]["max5"] or [0])
+                if max_scale == 0:
+                    model_dict[field_name] = (SkipValidation[int], Field(ge=min_min5, le=max_max5))
+                else:
+                    model_dict[field_name] = (
+                        SkipValidation[float],
+                        Field(ge=min_min5, le=max_max5, decimal_places=max_scale),
+                    )
+            elif field_name in datetime_fields:
+                model_dict[field_name] = (
+                    SkipValidation[str],
+                    Field(
+                        pattern=r"""(19\\d{2}|20\\d{2})-(0[1-9]|1[0-2])-(0[1-9]|1[0-9]|2[0-9]|3[0-1])T([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])"""
+                    ),
+                )
+            else:
+                model_dict[field_name] = (str, ...)
         schema = create_model("TargetModel", **model_dict, __base__=MostlyClassSchema)
         formatter_builder.append_str(f"{formatter_builder.json(schema, capture_name=None)}")
         formatter_builders.append(formatter_builder)
@@ -112,3 +132,129 @@ def from_json(cls, _json: str) -> "MostlyClassSchema":
                 f"Caught pydantic ValidationError {e}, reraising as JSONDecodeError", _json, 0
             )
         raise e
+
+
+# copy formatron: direct copy from formatron
+def _string_metadata(current: type, nonterminal: str):
+    min_length = current.metadata.get("min_length")
+    max_length = current.metadata.get("max_length")
+    pattern = current.metadata.get("pattern")
+    substring_of = current.metadata.get("substring_of")
+    if pattern:
+        assert not (min_length or max_length or substring_of), (
+            "pattern is mutually exclusive with min_length, max_length and substring_of"
+        )
+    if substring_of:
+        assert not (min_length or max_length or pattern), (
+            "substring_of is mutually exclusive with min_length, max_length and pattern"
+        )
+    repetition_map = {
+        (True, False): f"{{{min_length},}}",
+        (False, True): f"{{0,{max_length}}}",
+        (True, True): f"{{{min_length},{max_length}}}",
+    }
+    repetition = repetition_map.get((min_length is not None, max_length is not None))
+    if repetition is not None:
+        return (
+            rf"""{nonterminal} ::= #'"([^\\\\"\u0000-\u001f]|\\\\["\\\\bfnrt/]|\\\\u[0-9A-Fa-f]{{4}}){repetition}"';
+""",
+            [],
+        )
+    if pattern is not None:
+        pattern = pattern.replace("'", "\\'")
+        return f"""{nonterminal} ::= #'"{pattern}"';\n""", []
+    if substring_of is not None:
+        return f"""{nonterminal} ::= '"' #substrs{repr(substring_of)} '"';\n""", []
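Editor's note on the altered number handler below: Field(ge=..., le=...) is not enforced as a true range check — only digit counts (and a leading sign) are constrained in the generated grammar, and the actual bounds are only enforced later by decode_language_numeric via min5/max5 clipping. Roughly, under that reading:

    # Illustrative only (assumes the _number_metadata below): an int field with
    # Field(ge=0, le=50) yields a KBNF rule like
    #     nonterminal ::= #"(0|[1-9][0-9]{0,1})"
    # which still admits 99 - the digit count, not the bound, is what is checked.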
+# completely altered vs formatron
+def _number_metadata(current: type, nonterminal: str):
+    # For now only constrains number of digits and whether it is negative
+    gt = current.metadata.get("gt")
+    ge = current.metadata.get("ge")
+    lt = current.metadata.get("lt")
+    le = current.metadata.get("le")
+    if lt is not None or gt is not None:
+        raise NotImplementedError("gt and lt are not supported for number metadata")
+    if ge is None or le is None:
+        raise NotImplementedError("open-ended ranges (ge or le missing) are not supported")
+    if le < ge:
+        raise ValueError("le must be greater than or equal to ge")
+
+    pattern_parts = []
+    if issubclass(current.type, float):
+        # format(..., "f") avoids scientific notation; a bound may lack a fractional part
+        le_int, _, le_frac = format(le, "f").partition(".")
+        ge_int, _, ge_frac = format(ge, "f").partition(".")
+        le, ge = int(le_int), int(ge_int)
+        decimal_places = current.metadata.get("decimal_places")
+
+    if ge is not None and le is not None:
+        if ge < 0 and le < 0:
+            pattern_parts.append("-")
+            min_num = abs(le)
+            max_num = abs(ge)
+            max_digits = len(str(max_num))
+            min_digits = len(str(min_num))
+            pattern_parts.append(rf"([1-9][0-9]{{{min_digits - 1},{max_digits - 1}}})")
+        elif ge > 0:
+            min_num = ge
+            max_num = le
+            max_digits = len(str(max_num))
+            min_digits = len(str(min_num))
+            pattern_parts.append(rf"([1-9][0-9]{{{min_digits - 1},{max_digits - 1}}})")
+        else:
+            if ge < 0:
+                pattern_parts.append("-?")
+            max_digits = max(len(str(abs(ge))), len(str(abs(le))))
+            pattern_parts.append(rf"(0|[1-9][0-9]{{0,{max_digits - 1}}})")
+
+    if issubclass(current.type, float):
+        pattern_parts.append(rf"(\\.[0-9]{{0,{decimal_places}}})?")
+
+    pattern = "".join(pattern_parts)
+    return f"""{nonterminal} ::= #"{pattern}";\n""", []
+
+
+# copy formatron: removed sequence metadata since unnecessary and altered number_metadata to use ours
+def _metadata(current: type, nonterminal: str):
+    if isinstance(current, schemas.schema.TypeWithMetadata):
+        original = typing.get_origin(current.type)
+        if original is None:
+            original = current.type
+        if not current.metadata:
+            return "", [(current.type, nonterminal)]
+        if isinstance(current.type, type) and issubclass(current.type, str):
+            return _string_metadata(current, nonterminal)
+        elif isinstance(current.type, type) and issubclass(current.type, (int, float)):
+            return _number_metadata(current, nonterminal)
+    return None
+
+
+def monkey_patch_formatron():
+    FORMATRON_WHITESPACE_MAX_REPETITIONS = 10
+    SPACE_NONTERMINAL = f"[ \t\n\r]{{0,{FORMATRON_WHITESPACE_MAX_REPETITIONS}}}"
+
+    # Copy from formatron, altered to have limited whitespace repetitions and datetime format
+    json.GRAMMAR_HEADER = rf"""integer ::= #"-?(0|[1-9]\\d*)";
+number ::= #"-?(0|[1-9]\\d*)(\\.\\d+)?([eE][+-]?\\d+)?";
+string ::= #'"([^\\\\"\u0000-\u001f]|\\\\["\\\\bfnrt/]|\\\\u[0-9A-Fa-f]{{4}})*"';
+boolean ::= "true"|"false";
+null ::= "null";
+array ::= array_begin (json_value (comma json_value)*)? array_end;
+object ::= object_begin (string colon json_value (comma string colon json_value)*)? 
object_end; + json_value ::= number|string|boolean|null|array|object; + comma ::= #"{SPACE_NONTERMINAL},{SPACE_NONTERMINAL}"; + colon ::= #"{SPACE_NONTERMINAL}:{SPACE_NONTERMINAL}"; + object_begin ::= #" \\{{{SPACE_NONTERMINAL}"; + object_end ::= #"{SPACE_NONTERMINAL}\\}}"; + array_begin ::= #"\\[{SPACE_NONTERMINAL}"; + array_end ::= #"{SPACE_NONTERMINAL}\\]"; + """ + + def alter_type_to_nonterminals_metadata_inplace(type_to_nonterminals: list[typing.Callable]): + metadata_idx = [idx for idx, fn in enumerate(type_to_nonterminals) if fn.__name__ == "metadata"] + if len(metadata_idx) == 1: + type_to_nonterminals[metadata_idx[0]] = _metadata + + alter_type_to_nonterminals_metadata_inplace(json._type_to_nonterminals) diff --git a/mostlyai/engine/_language/generation.py b/mostlyai/engine/_language/generation.py index 6c57cd5..e712af1 100644 --- a/mostlyai/engine/_language/generation.py +++ b/mostlyai/engine/_language/generation.py @@ -32,10 +32,13 @@ from mostlyai.engine._common import ( persist_data_part, FixedSizeSampleBuffer, - STRING, ProgressCallback, ProgressCallbackWrapper, ) +from mostlyai.engine._encoding_types.language.categorical import decode_language_categorical +from mostlyai.engine._encoding_types.language.datetime import decode_language_datetime +from mostlyai.engine._encoding_types.language.numeric import decode_language_numeric +from mostlyai.engine._encoding_types.language.text import decode_text from mostlyai.engine._language.common import estimate_max_tokens, MAX_LENGTH from mostlyai.engine._language.encoding import encode_df from mostlyai.engine._workspace import ensure_workspace_dir, Workspace, reset_dir @@ -44,6 +47,7 @@ prepare_seed_for_formatron, get_vocab_processors, ) +from mostlyai.engine.domain import ModelEncodingType, RareCategoryReplacementMethod INVALID_VALUE = "_INVALID_" # when JSON parsing fails, the values of target columns will be set to this DUMMY_CONTEXT_KEY = "__dummy_context_key" @@ -53,7 +57,7 @@ def decode_buffered_samples( buffer: FixedSizeSampleBuffer, tokenizer: PreTrainedTokenizerBase, - tgt_text_columns: list[str], + tgt_stats: dict[str, str], tgt_context_key: str, max_new_tokens: int, ): @@ -78,6 +82,7 @@ def parse_json(x, columns: list[str]): num_samples_max_length_limit += sum(1 for tokens in num_tokens_by_row if tokens >= max_new_tokens) except AttributeError: num_samples_max_length_limit = float("-inf") + outputs_text = tokenizer.batch_decode(outputs_ids, skip_special_tokens=True) output_texts.extend(outputs_text) ctx_keys.append(keys_df) @@ -87,8 +92,8 @@ def parse_json(x, columns: list[str]): tgt_seed = pd.concat(tgt_seed, axis=0).reset_index(drop=True) # The model works with un-prefixed column names, but we need to recover prefixed column names for the final output tgt_data = pd.DataFrame( - [parse_json(text, tgt_text_columns) for text in output_texts], - columns=tgt_text_columns, + [parse_json(text, tgt_stats["columns"].keys()) for text in output_texts], + columns=tgt_stats["columns"].keys(), index=ctx_keys.index, dtype="string", ) @@ -98,14 +103,25 @@ def parse_json(x, columns: list[str]): ) # overwrite generated columns with the seeded values tgt_data.update(tgt_seed) - # ensure STRING type - tgt_data = tgt_data.astype(STRING) + # prepend the context keys to the data (if not dummy context) if ctx_keys.name != DUMMY_CONTEXT_KEY: tgt_data = pd.concat([ctx_keys, tgt_data], axis=1) - invalid_percentage = ((tgt_data[tgt_text_columns] == INVALID_VALUE).sum() / len(tgt_data) * 100.0).map( + invalid_percentage = 
((tgt_data[tgt_stats["columns"].keys()] == INVALID_VALUE).sum() / len(tgt_data) * 100.0).map( "{:.2f}%".format ) + + for col in tgt_stats["columns"].keys(): + col_stats = tgt_stats["columns"][col] + if col_stats["encoding_type"] == ModelEncodingType.language_numeric: + tgt_data[col] = decode_language_numeric(tgt_data[col], col_stats) + elif col_stats["encoding_type"] == ModelEncodingType.language_datetime: + tgt_data[col] = decode_language_datetime(tgt_data[col], col_stats) + elif col_stats["encoding_type"] == ModelEncodingType.language_categorical: + tgt_data[col] = decode_language_categorical(tgt_data[col], col_stats) + else: + tgt_data[col] = decode_text(tgt_data[col], col_stats) + _LOG.info(f"percentage of invalid values: {invalid_percentage.to_dict()}") _LOG.info(f"decoded {tgt_data.shape} from {len(buffer.buffer)} batches in {time.time() - t0:.2f}s") return tgt_data @@ -119,6 +135,7 @@ def generate( batch_size: int | None = None, sampling_temperature: float = 1.0, sampling_top_p: float = 1.0, + rare_category_replacement_method: RareCategoryReplacementMethod | str = RareCategoryReplacementMethod.constant, device: torch.device | str | None = None, workspace_dir: str | Path = "engine-ws", update_progress: ProgressCallback | None = None, @@ -162,7 +179,6 @@ def tqdm_disabled(): if has_context: ctx_stats = workspace.ctx_stats.read() - ctx_columns = list(ctx_stats["columns"].keys()) ctx_primary_key = ctx_stats["keys"].get("primary_key") # ensure ctx_data exists @@ -187,11 +203,11 @@ def tqdm_disabled(): sample_size = len(ctx_data) _LOG.info(f"{sample_size=}") else: + ctx_stats = None # create on-the-fly context if sample_size is None: trn_sample_size = tgt_stats["no_of_training_records"] + tgt_stats["no_of_validation_records"] sample_size = trn_sample_size if sample_size is None else sample_size - ctx_columns = [] ctx_primary_key = tgt_context_key = DUMMY_CONTEXT_KEY ctx_data = pd.DataFrame({ctx_primary_key: range(sample_size)}) @@ -213,7 +229,7 @@ def tqdm_disabled(): return # encode context data - encoded_ctx_data = encode_df(ctx_df=ctx_data, ctx_columns=ctx_columns) + encoded_ctx_data = encode_df(ctx_df=ctx_data, ctx_stats=ctx_stats) # estimate max new tokens based on char length of original data; consider JSON overhead max_new_tokens = estimate_max_tokens(tgt_stats) @@ -247,7 +263,6 @@ def tqdm_disabled(): # prepare seed data for clean consumption by formatron seed_data = prepare_seed_for_formatron(seed_data, engine.tokenizer) seeded_tgt_columns = seed_data.columns.to_list() - unseeded_tgt_columns = [c for c in tgt_text_columns if c not in seeded_tgt_columns] total_tokenize_fn_time = 0 total_logits_processor_build_time = 0 @@ -259,7 +274,9 @@ def tqdm_disabled(): if enforce_json_output and len(seeded_tgt_columns) == 0: t0 = time.time() - formatter_builders = get_formatter_builders(size=batch_size, unseeded_fields=unseeded_tgt_columns) + formatter_builders = get_formatter_builders( + size=batch_size, stats=tgt_stats, rare_category_replacement_method=rare_category_replacement_method + ) engine.initialize_logits_processors(formatter_builders, formatron_vocab_processors) total_logits_processor_build_time += time.time() - t0 @@ -279,7 +296,8 @@ def tqdm_disabled(): # some columns are seeded, so we need to create a new logits processor for each batch formatter_builders = get_formatter_builders( seed_df=sample_seed_batch, - unseeded_fields=unseeded_tgt_columns, + stats=tgt_stats, + rare_category_replacement_method=rare_category_replacement_method, ) 
engine.initialize_logits_processors(formatter_builders, formatron_vocab_processors)
                total_logits_processor_build_time += time.time() - t0
@@ -295,7 +313,7 @@
             buffer.add((outputs, ctx_keys, sample_seed_batch))
             if buffer.is_full():
                 decoded_data = decode_buffered_samples(
-                    buffer, engine.tokenizer, tgt_text_columns, tgt_context_key, max_new_tokens
+                    buffer, engine.tokenizer, tgt_stats, tgt_context_key, max_new_tokens
                 )
                 persist_data_part(
                     decoded_data,
@@ -307,9 +325,7 @@
             samples_processed += len(ctx_batch)
 
         if not buffer.is_empty():
-            decoded_data = decode_buffered_samples(
-                buffer, engine.tokenizer, tgt_text_columns, tgt_context_key, max_new_tokens
-            )
+            decoded_data = decode_buffered_samples(buffer, engine.tokenizer, tgt_stats, tgt_context_key, max_new_tokens)
             persist_data_part(
                 decoded_data,
                 output_path,
diff --git a/mostlyai/engine/_language/tokenizer_utils.py b/mostlyai/engine/_language/tokenizer_utils.py
index ff2cba0..64c13e5 100644
--- a/mostlyai/engine/_language/tokenizer_utils.py
+++ b/mostlyai/engine/_language/tokenizer_utils.py
@@ -19,13 +19,21 @@
 from transformers import DataCollatorForLanguageModeling, BatchEncoding, PreTrainedTokenizerFast, LlamaTokenizerFast
 from transformers.data.data_collator import pad_without_fast_tokenizer_warning, _torch_collate_batch
 
+from mostlyai.engine.domain import ModelEncodingType
+
 
 #################
 ### TOKENIZER ###
 #################
 
 
-def train_tokenizer(training_iterator: Iterator | list | None = None, tokenizer_kwargs=None):
+def train_tokenizer(
+    training_iterator: Iterator | list | None = None,
+    tokenizer_kwargs: dict[str, Any] | None = None,
+    tgt_stats: dict[str, Any] | None = None,
+):
     if tokenizer_kwargs is None:
         tokenizer_kwargs = {}
+    if tgt_stats is None:
+        tgt_stats = {"columns": {}}  # tolerate calls without stats; adds no extra alphabet
     from tokenizers import Tokenizer, decoders
@@ -46,10 +52,26 @@
     MIN_FREQ_MERGE = 20
     VOCAB_SIZE = 5000
 
+    # add initial alphabet for numeric and datetime columns if needed
+    has_numeric_columns = any(
+        col_stats["encoding_type"] == ModelEncodingType.language_numeric for col_stats in tgt_stats["columns"].values()
+    )
+    has_datetime_columns = any(
+        col_stats["encoding_type"] == ModelEncodingType.language_datetime for col_stats in tgt_stats["columns"].values()
+    )
+    initial_alphabet = set()
+    if has_numeric_columns:
+        # FIXME: maybe the set can be more fine-grained based on max_scale in stats
+        initial_alphabet |= {str(i) for i in range(10)} | {".", "-", "+", "e", "E"}
+    if has_datetime_columns:
+        initial_alphabet |= {str(i) for i in range(10)} | {".", "-", ":", "T", "Z"}
+    initial_alphabet = list(initial_alphabet)
+
     # Builds a BPE raw_tokenizer, and optionally trains it based on provided text
     training_iterator = training_iterator or []  # allow easy training skip
     raw_tokenizer = Tokenizer(BPE(unk_token=special_tokens["unk_token"]))
     trainer = BpeTrainer(
+        initial_alphabet=initial_alphabet,
         special_tokens=SPECIAL_TOKENS,
         min_frequency=MIN_FREQ_MERGE,
         vocab_size=VOCAB_SIZE,
diff --git a/mostlyai/engine/_language/training.py b/mostlyai/engine/_language/training.py
index 4a21da8..15eebc4 100644
--- a/mostlyai/engine/_language/training.py
+++ b/mostlyai/engine/_language/training.py
@@ -38,7 +38,6 @@
 from torch.utils.data import DataLoader
 
 from mostlyai.engine._common import (
-    STRING,
     ProgressCallback,
     ProgressCallbackWrapper,
     TABLE_COLUMN_INFIX,
@@ -272,7 +271,7 @@ def train(
     raw_dataset = load_dataset("parquet", data_files=data_files)
 
     def shuffle_tgt_columns(x):
-        x_tgt = 
pd.DataFrame([json.loads(x.pop("tgt"))], dtype=STRING) # convert to DataFrame + x_tgt = pd.DataFrame([json.loads(x.pop("tgt"))]) # convert to DataFrame x_tgt = x_tgt.sample(frac=1, axis=1) # shuffle columns x_tgt = row_to_json( x_tgt.add_prefix("tgt" + TABLE_COLUMN_INFIX).squeeze(axis=0), is_target=True @@ -352,7 +351,7 @@ def concat_prompt_and_response(x): for i in range(0, len(content_dataset["train"]), 1_000) ) # train a custom tokenizer and convert it to a LlamaTokenizerFast object - tokenizer = train_tokenizer(tokenizer_train_iter, tokenizer_kwargs=tokenizer_args) + tokenizer = train_tokenizer(tokenizer_train_iter, tokenizer_kwargs=tokenizer_args, tgt_stats=tgt_stats) model_config = LSTMFromScratchConfig(vocab_size=len(tokenizer), with_dp=with_dp) model = LSTMFromScratchLMHeadModel(model_config).to(device) else: diff --git a/mostlyai/engine/analysis.py b/mostlyai/engine/analysis.py index 1094c61..2a27b77 100644 --- a/mostlyai/engine/analysis.py +++ b/mostlyai/engine/analysis.py @@ -41,6 +41,11 @@ ProgressCallback, ProgressCallbackWrapper, ) +from mostlyai.engine._encoding_types.language.datetime import ( + analyze_reduce_language_datetime, + analyze_language_datetime, +) +from mostlyai.engine._encoding_types.language.numeric import analyze_language_numeric, analyze_reduce_language_numeric from mostlyai.engine._encoding_types.tabular.categorical import ( analyze_categorical, analyze_reduce_categorical, @@ -66,6 +71,10 @@ analyze_text, analyze_reduce_text, ) +from mostlyai.engine._encoding_types.language.categorical import ( + analyze_language_categorical, + analyze_reduce_language_categorical, +) from mostlyai.engine.domain import ModelEncodingType from mostlyai.engine._workspace import ( @@ -84,6 +93,9 @@ ModelEncodingType.tabular_numeric_binned, ModelEncodingType.tabular_datetime, ModelEncodingType.tabular_datetime_relative, + ModelEncodingType.language_categorical, + ModelEncodingType.language_numeric, + ModelEncodingType.language_datetime, ) @@ -313,22 +325,7 @@ def _analyze_reduce( column: column_stats.get("encoding_type") for column, column_stats in stats_list[0]["columns"].items() } - # build mapping of original column name to ARGN table and column identifiers - def get_table(qualified_column_name: str) -> str: - # column names are assumed to be