feat: LANGUAGE encoding types #29

Merged: 60 commits from msd-965-language-enctypes into main on Feb 18, 2025

Commits
2aff54a
feat: add support for language categorical encoding and analysis, wip1
radurogojanumai Feb 5, 2025
1c7d547
wip
lukaszkolodziejczyk Feb 5, 2025
a16a6a1
language_numeric
lukaszkolodziejczyk Feb 6, 2025
827679a
language_numeric
lukaszkolodziejczyk Feb 6, 2025
2e10629
language_numeric
lukaszkolodziejczyk Feb 6, 2025
e4e0e4f
language_numeric
lukaszkolodziejczyk Feb 6, 2025
a9b681d
add test and beginnings of numeric
andre-mostly Feb 6, 2025
ab5ce5a
language_numeric
lukaszkolodziejczyk Feb 6, 2025
52c8303
fix
lukaszkolodziejczyk Feb 6, 2025
ce61fed
remove unnecessary
lukaszkolodziejczyk Feb 6, 2025
e2e23a4
change test to AMD, fix max_decimal
andre-mostly Feb 6, 2025
988d3c8
fix decimal
andre-mostly Feb 6, 2025
e3dc810
max_scale
lukaszkolodziejczyk Feb 6, 2025
926ec2b
encode string/numeric
andre-mostly Feb 6, 2025
abda5d5
placeholder datetime
andre-mostly Feb 6, 2025
337003c
add dt test asserts
andre-mostly Feb 6, 2025
c668e67
datetime
lukaszkolodziejczyk Feb 6, 2025
8d5e764
add coercion
andre-mostly Feb 6, 2025
5913da9
fix invalid to nan for numeric
andre-mostly Feb 7, 2025
87fb34f
coerce datetimes to valid dates, add datetime grammar, change test to…
andre-mostly Feb 7, 2025
548f11a
comments and refactor
andre-mostly Feb 7, 2025
31fea13
tiny refactor
michdr Feb 7, 2025
6228782
add back categorical test
andre-mostly Feb 7, 2025
fea28fb
Merge branch 'main' into msd-965-language-enctypes
michdr Feb 10, 2025
386d976
Simpler categorical analyze / analyze reduce; rare category protectio…
lukaszkolodziejczyk Feb 10, 2025
6afe58f
expose rare_category_replacement_method for LANGUAGE
lukaszkolodziejczyk Feb 10, 2025
f7948bc
MSD-XXX: add initial alphabets to untrained tokenizer if needed (#33)
shuangwu5 Feb 10, 2025
ac2c1cc
simplify numeric
lukaszkolodziejczyk Feb 11, 2025
65b9dd6
fix tests (#34)
shuangwu5 Feb 11, 2025
a1bdade
refactor temp_formatron.py (#35)
shuangwu5 Feb 11, 2025
f3ec656
fix several _decode_numeric and _decode_datetime FIXMEs
michdr Feb 11, 2025
53b5e4a
ruff
michdr Feb 11, 2025
c4fb39b
Extreme value protection for LANGUAGE_NUMERIC (#36)
lukaszkolodziejczyk Feb 11, 2025
7456760
extreme value protection for datetimes
lukaszkolodziejczyk Feb 11, 2025
28516f6
build: uv run without re-syncing the environment (#37)
shuangwu5 Feb 11, 2025
5d8011f
temp fix for datetime validation
michdr Feb 11, 2025
4eccc0d
enhance test_categorical_numeric_datetime
michdr Feb 12, 2025
d52772e
enable all models in test_categorical_numeric_datetime
michdr Feb 12, 2025
e343ace
fix _decode_datetime
michdr Feb 12, 2025
9453080
feat: constrain numeric and simplify datetime (#38)
andre-mostly Feb 13, 2025
d66670d
refactor: move language decode functions + tabular encoding types uni…
michdr Feb 13, 2025
ccbb1e0
add decode_text
michdr Feb 13, 2025
c643d25
ruff
michdr Feb 13, 2025
8002963
Merge branch 'main' into msd-965-language-enctypes
michdr Feb 13, 2025
88b5772
restrict number to correct number of decimal points in grammar, negat…
andre-mostly Feb 13, 2025
8a3506c
fix datetime pattern
andre-mostly Feb 13, 2025
6ee92e1
add unit tests + improve decode_numeric
michdr Feb 13, 2025
9c20ee2
fix monkey patch
andre-mostly Feb 13, 2025
089cf62
remove temp_formatron.py and move code into formatron_utils.py
andre-mostly Feb 13, 2025
6692258
fix numeric training
lukaszkolodziejczyk Feb 13, 2025
da28b5c
make max5, min5 maintain numeric dtype (int for int and float for flo…
andre-mostly Feb 13, 2025
7b4bc2f
added description for new enc types
mplatzer Feb 14, 2025
6e00e98
re-add disabled models in test_categorical_numeric_datetime
michdr Feb 14, 2025
743ec3e
fix comments
andre-mostly Feb 17, 2025
77e55c0
refactor analyze
lukaszkolodziejczyk Feb 17, 2025
33a3b1e
LANGUAGE CATEGORICAL (#43)
lukaszkolodziejczyk Feb 17, 2025
7875908
n_jobs
lukaszkolodziejczyk Feb 17, 2025
be41a73
kill examples/language_encoding_types.ipynb
lukaszkolodziejczyk Feb 17, 2025
d0452c4
datetime tests
lukaszkolodziejczyk Feb 17, 2025
2548431
numeric tests
lukaszkolodziejczyk Feb 17, 2025
Changes from all commits
10 changes: 5 additions & 5 deletions .github/workflows/run-tests-cpu.yaml
@@ -32,16 +32,16 @@ jobs:
run: uv sync --frozen --extra cpu

- name: Run | Tests -> unit
-        run: uv run pytest tests/unit
+        run: uv run --no-sync pytest tests/unit

- name: Build mkdocs
-        run: uv run mkdocs build --strict
+        run: uv run --no-sync mkdocs build --strict

- name: Run tests -> end_to_end -> sequential
-        run: uv run pytest tests/end_to_end/test_tabular_sequential.py
+        run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential.py

- name: Run tests -> end_to_end -> sequential context
-        run: uv run pytest tests/end_to_end/test_tabular_sequential_context.py
+        run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential_context.py

run-tests-cpu-end-to-end-nonsequential:
runs-on: ubuntu-latest
@@ -66,4 +66,4 @@ jobs:
run: uv sync --frozen --extra cpu --no-group docs

- name: Run tests -> end_to_end all except sequential
-        run: uv run pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/
+        run: uv run --no-sync pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/
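
Note: the change repeated across these workflow steps (and in the Makefile below) swaps `uv run …` for `uv run --no-sync …`. The environment is already prepared by the preceding `uv sync --frozen` step, so `--no-sync` skips the implicit environment sync that `uv run` would otherwise perform on every invocation (see commit "build: uv run without re-syncing the environment (#37)").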
6 changes: 3 additions & 3 deletions .github/workflows/run-tests-gpu.yaml
@@ -42,10 +42,10 @@ jobs:
run: nvidia-smi

- name: Run tests -> end_to_end -> sequential
-        run: uv run pytest tests/end_to_end/test_tabular_sequential.py
+        run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential.py

- name: Run tests -> end_to_end -> sequential context
-        run: uv run pytest tests/end_to_end/test_tabular_sequential_context.py
+        run: uv run --no-sync pytest tests/end_to_end/test_tabular_sequential_context.py

- name: Run tests -> end_to_end all except sequential
-        run: uv run pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/
+        run: uv run --no-sync pytest --ignore=tests/end_to_end/test_tabular_sequential.py --ignore=tests/end_to_end/test_tabular_sequential_context.py tests/end_to_end/
4 changes: 2 additions & 2 deletions Makefile
@@ -12,11 +12,11 @@ install: # Install dependencies

.PHONY: lint
lint: ## Run lints
-	uv run pre-commit run --all-files
+	uv run --no-sync pre-commit run --all-files

.PHONY: test
test: ## Run tests
-	uv run pytest
+	uv run --no-sync pytest

.PHONY: all
all: clean install lint test ## Run the commands: clean install lint test
75 changes: 75 additions & 0 deletions mostlyai/engine/_encoding_types/language/categorical.py
@@ -0,0 +1,75 @@
# Copyright 2025 MOSTLY AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Categorical encoding for language models.
"""

import numpy as np
import pandas as pd

from mostlyai.engine._common import safe_convert_string, STRING

CATEGORICAL_UNKNOWN_TOKEN = "_RARE_"


def analyze_language_categorical(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict:
values = safe_convert_string(values)
# count distinct root_keys per categorical value for rare-category protection
df = pd.concat([root_keys, values], axis=1)
cnt_values = df.groupby(values.name)[root_keys.name].nunique().to_dict()
stats = {"has_nan": sum(values.isna()) > 0, "cnt_values": cnt_values}
return stats


def analyze_reduce_language_categorical(stats_list: list[dict], value_protection: bool = True) -> dict:
# sum up all counts for each categorical value
cnt_values: dict[str, int] = {}
for item in stats_list:
for value, count in item["cnt_values"].items():
cnt_values[value] = cnt_values.get(value, 0) + count
# create alphabetically sorted list of non-rare categories
    known_categories = sorted(cnt_values.keys())
if value_protection:
# stochastic threshold for rare categories
rare_min = 5 + int(3 * np.random.uniform())
else:
rare_min = 0
categories = [k for k in known_categories if cnt_values[k] >= rare_min]
no_of_rare_categories = len(known_categories) - len(categories)
# add None to categories, if any are present
if any([j["has_nan"] for j in stats_list]):
categories = [None] + categories
# add special token for UNKNOWN categories at first position
if no_of_rare_categories > 0:
categories = [CATEGORICAL_UNKNOWN_TOKEN] + categories
stats = {"no_of_rare_categories": no_of_rare_categories, "categories": categories}
return stats


def encode_language_categorical(values: pd.Series, stats: dict) -> pd.Series:
values = safe_convert_string(values)
values = values.copy()
known_categories = stats["categories"]
mask = ~values.isin(known_categories)
if None in known_categories:
mask &= ~pd.isna(values)
values[mask] = CATEGORICAL_UNKNOWN_TOKEN
return values


def decode_language_categorical(x: pd.Series, col_stats: dict[str, str]) -> pd.Series:
x = x.astype(STRING)
allowed_categories = col_stats.get("categories", [])
return x.where(x.isin(allowed_categories), other=None)
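
As a quick illustration of how these four functions compose — a minimal sketch, assuming only the module path from this diff; the "color"/"subject_id" names and the toy data are invented for illustration:

import pandas as pd

from mostlyai.engine._encoding_types.language.categorical import (
    analyze_language_categorical,
    analyze_reduce_language_categorical,
    decode_language_categorical,
    encode_language_categorical,
)

# toy column with one frequent and one rare category (hypothetical names)
values = pd.Series(["red"] * 20 + ["blue"], name="color")
root_keys = pd.Series(range(21), name="subject_id")

# per-partition analysis, then reduction across partitions
stats = analyze_reduce_language_categorical(
    [analyze_language_categorical(values, root_keys)], value_protection=True
)

# "blue" is backed by a single root key, so it falls below the stochastic
# rare-category threshold (5 to 7) and is replaced by _RARE_ during encoding
encoded = encode_language_categorical(values, stats)

# decoding maps anything outside the known categories back to None
decoded = decode_language_categorical(encoded, stats)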
143 changes: 143 additions & 0 deletions mostlyai/engine/_encoding_types/language/datetime.py
@@ -0,0 +1,143 @@
# Copyright 2025 MOSTLY AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import calendar

import numpy as np
import pandas as pd

from mostlyai.engine._common import safe_convert_datetime


def analyze_language_datetime(values: pd.Series, root_keys: pd.Series, _: pd.Series | None = None) -> dict:
values = safe_convert_datetime(values)
df = pd.concat([root_keys, values], axis=1)
    # determine lowest/highest values by root ID, and keep the 11 most extreme of each
min_dates = df.groupby(root_keys.name)[values.name].min().dropna()
min11 = min_dates.sort_values(ascending=True).head(11).astype(str).tolist()
max_dates = df.groupby(root_keys.name)[values.name].max().dropna()
max11 = max_dates.sort_values(ascending=False).head(11).astype(str).tolist()
# determine if there are any NaN values
has_nan = bool(values.isna().any())
# return stats
stats = {
"has_nan": has_nan,
"min11": min11,
"max11": max11,
}
return stats


def analyze_reduce_language_datetime(stats_list: list[dict], value_protection: bool = True) -> dict:
# check if there are missing values
has_nan = any([j["has_nan"] for j in stats_list])
    # determine the 5 min / max values onto which too-low / too-high values will be mapped
min11 = sorted([v for min11 in [j["min11"] for j in stats_list] for v in min11], reverse=False)[:11]
max11 = sorted([v for max11 in [j["max11"] for j in stats_list] for v in max11], reverse=True)[:11]
if value_protection:
# extreme value protection - discard lowest/highest 5 values
if len(min11) < 11 or len(max11) < 11:
# less than 11 subjects with non-NULL values; we need to protect all
min5 = []
max5 = []
else:
            min5 = [str(v) for v in min11[5:10]]  # drop the 5 lowest; keep the 6th to 10th lowest
            max5 = [str(v) for v in max11[5:10]]  # drop the 5 highest; keep the 6th to 10th highest
else:
        min5 = min11[0:5]  # no protection: keep the 5 lowest as-is
        max5 = max11[0:5]  # no protection: keep the 5 highest as-is
stats = {
"has_nan": has_nan,
"min5": min5,
"max5": max5,
}
return stats


def encode_language_datetime(values: pd.Series, stats: dict, _: pd.Series | None = None) -> pd.Series:
# convert
values = safe_convert_datetime(values)
values = values.copy()
# reset index, as `values.mask` can throw errors for misaligned indices
values.reset_index(drop=True, inplace=True)
# replace extreme values with randomly sampled 5-th to 10-th largest/smallest values
min5 = stats["min5"] if len(stats["min5"]) > 0 else [0]
max5 = stats["max5"] if len(stats["max5"]) > 0 else [0]
min5 = pd.Series(min5, dtype=values.dtype)
max5 = pd.Series(max5, dtype=values.dtype)
values.mask(
values < min5[0],
min5.sample(n=len(values), replace=True, ignore_index=True),
inplace=True,
)
values.mask(
values > max5[0],
max5.sample(n=len(values), replace=True, ignore_index=True),
inplace=True,
)
return values


def _clip_datetime(x: pd.Series, min5: list, max5: list) -> pd.Series:
x_dt = pd.to_datetime(x, errors="coerce")
min_arr = pd.to_datetime(min5).to_numpy(dtype="datetime64[ns]")
max_arr = pd.to_datetime(max5).to_numpy(dtype="datetime64[ns]")
n = len(x_dt)
random_mins = np.random.choice(min_arr, size=n)
random_maxs = np.random.choice(max_arr, size=n)
clipped = np.minimum(np.maximum(x_dt.to_numpy(dtype="datetime64[ns]"), random_mins), random_maxs)
return pd.Series(clipped, index=x.index)


def decode_language_datetime(x: pd.Series, col_stats: dict[str, str]) -> pd.Series:
x = x.where(~x.isin(["", "_INVALID_"]), np.nan)

valid_mask = (
x.str.len().ge(10)
& x.str.slice(0, 4).str.isdigit()
& x.str.slice(5, 7).str.isdigit()
& x.str.slice(8, 10).str.isdigit()
)
if valid_mask.sum() > 0: # expected "YYYY-MM-DD" prefix
# handle the date portion, ensuring validity
years = x[valid_mask].str.slice(0, 4).astype(int)
months = x[valid_mask].str.slice(5, 7).astype(int)
days = x[valid_mask].str.slice(8, 10).astype(int)

# clamp days according to maximum possible day of the month of a given year
last_days = np.array([calendar.monthrange(y, m)[1] for y, m in zip(years, months)])
clamped_days = np.minimum(days, last_days)

# rebuild the date portion
new_date = (
years.astype(str).str.zfill(4)
+ "-"
+ months.astype(str).str.zfill(2)
+ "-"
+ pd.Series(clamped_days, index=years.index).astype(str).str.zfill(2)
)

# handle the time portion, ensuring validity
remainder = x[valid_mask].str.slice(10)

time_regex = r"^[ T]?(\d{2}:\d{2}:\d{2}(?:\.\d+)?)"
valid_time = remainder.str.extract(time_regex, expand=False)
valid_time = valid_time.fillna("00:00:00")
valid_time = " " + valid_time

new_date = new_date + valid_time
x.loc[valid_mask] = new_date

x = pd.to_datetime(x, errors="coerce")
x = _clip_datetime(x, col_stats["min5"], col_stats["max5"])
return x.astype("datetime64[ns]")