Fix bug due to precision in CategoricalMatrix._get_col_stds #391

Merged

mlondschien

This fixes a bug in CategoricalMatrix._get_col_stds. With float32 precision, we sometimes got NaN as the standard deviation of constant categorical columns: rounding can make mean - col_means**2 slightly negative, so np.sqrt returns NaN. This then led to NaN in P1 and P2 in glum.
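To illustrate how float32 rounding alone can move a weighted mean away from its exact value, here is a minimal sketch. It assumes uniform weights of 1/n and a naive left-to-right float32 accumulation; it is not a trace of how tabmat actually accumulates, and the variable names are made up for the illustration.

```python
import numpy as np

n = 1_234_456  # same number of rows as in the reproduction
# Uniform weights that sum to 1 in exact arithmetic.
w = np.full(n, 1.0 / n, dtype=np.float32)

# np.cumsum accumulates left to right in the array's dtype, so its last
# element is a naive sequential float32 sum of the weights.
sum32 = float(np.cumsum(w)[-1])

# Accumulating the same float32 weights in float64 stays very close to 1.
sum64 = float(np.sum(w, dtype=np.float64))

print(f"float32 sequential sum: {sum32!r}")
print(f"float64 sum:            {sum64!r}")
```

If the weights of an (almost) always-active category add up to slightly more than 1 in this way, the column mean exceeds 1 and mean - col_means**2 becomes negative.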

To reproduce:

import numpy as np
import tabmat
import pandas as pd
import glum

n = 1_234_456
rng = np.random.default_rng(0)
p = 0.00001

x = rng.choice([0, 1], p=[p, 1-p], size=n)
cat = pd.Categorical.from_codes(x, categories=np.array(range(3), dtype=np.int32))
mat = tabmat.CategoricalMatrix(cat, column_name="cat", dtype=np.float32)

y = np.random.choice([0, 1], size=n).astype(np.float32)

weights = np.ones(n, dtype=np.float32)
weights /= weights.sum()

X, means, col_stds = mat.standardize(weights, True, True)
print(f"means: {means}, col_stds: {col_stds}")

glm = glum.GeneralizedLinearRegressor(family="binomial", alpha=0.01, l1_ratio=0.5)
glm.fit(mat, y)

Output:
/Users/mlondschien/mambaforge/envs/icufm/lib/python3.10/site-packages/tabmat/categorical_matrix.py:607: RuntimeWarning: invalid value encountered in sqrt
  return np.sqrt(mean - col_means**2)
means: [9.7208804e-06 1.0003393e+00 0.0000000e+00], col_stds: [0.00311782        nan 0.        ]
/Users/mlondschien/mambaforge/envs/icufm/lib/python3.10/site-packages/tabmat/categorical_matrix.py:607: RuntimeWarning: invalid value encountered in sqrt
  return np.sqrt(mean - col_means**2)
Traceback (most recent call last):
  File "/Users/mlondschien/code/projects2024-icu-foundation/tmp_tabmat.py", line 22, in <module>
    glm.fit(mat, y)
  File "/Users/mlondschien/mambaforge/envs/icufm/lib/python3.10/site-packages/glum/_glm.py", line 3244, in fit
    coef = self._solve(
  File "/Users/mlondschien/mambaforge/envs/icufm/lib/python3.10/site-packages/glum/_glm.py", line 1113, in _solve
    coef, self.n_iter_, self._n_cycles, self.diagnostics_ = _irls_solver(
  File "/Users/mlondschien/mambaforge/envs/icufm/lib/python3.10/site-packages/glum/_solvers.py", line 288, in _irls_solver
    state.eta, state.mu, state.obj_val, coef_P2 = _update_predictions(
  File "/Users/mlondschien/mambaforge/envs/icufm/lib/python3.10/site-packages/glum/_solvers.py", line 638, in _update_predictions
    return eta_mu_objective(
  File "/Users/mlondschien/mambaforge/envs/icufm/lib/python3.10/site-packages/glum/_solvers.py", line 671, in eta_mu_objective
    obj_val += linalg.norm(P1 * coef[intercept_offset:], ord=1)
  File "/Users/mlondschien/mambaforge/envs/icufm/lib/python3.10/site-packages/scipy/linalg/_misc.py", line 146, in norm
    a = np.asarray_chkfinite(a)
  File "/Users/mlondschien/mambaforge/envs/icufm/lib/python3.10/site-packages/numpy/lib/function_base.py", line 628, in asarray_chkfinite
    raise ValueError(
ValueError: array must not contain infs or NaNs
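
The NaN can be reproduced from the printed means alone. The sketch below is an illustration, not necessarily the exact change made in this PR: it assumes that for 0/1 indicator columns the weighted mean of squares equals the column mean, and it shows one possible guard, clamping the variance at zero before taking the square root. The function names col_std_naive and col_std_clipped are hypothetical.

```python
import numpy as np


def col_std_naive(mean_sq, col_means):
    # Same expression as in the warning above: sqrt(E[x^2] - E[x]^2).
    return np.sqrt(mean_sq - col_means**2)


def col_std_clipped(mean_sq, col_means):
    # Hypothetical guard: a variance cannot be negative, so clamp at 0.
    return np.sqrt(np.maximum(mean_sq - col_means**2, 0))


# Column means as printed in the output above (float32).
col_means = np.array([9.7208804e-06, 1.0003393e00, 0.0], dtype=np.float32)
# Assumption: for indicator columns x**2 == x, so E[x^2] == E[x].
mean_sq = col_means.copy()

with np.errstate(invalid="ignore"):  # silence the sqrt RuntimeWarning
    naive = col_std_naive(mean_sq, col_means)
clipped = col_std_clipped(mean_sq, col_means)

print("naive:  ", naive)
print("clipped:", clipped)
```

With a mean of about 1.00034, the variance term is roughly -0.00034, so the naive version returns NaN for the second column while the clamped version returns 0, which matches what one would expect for a constant column.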

I am not sure where a test for this should go.

Co-authored-by: Luca Bittarello <15511539+lbittarello@users.noreply.github.com>
@lbittarello lbittarello merged commit 24782f8 into Quantco:main Sep 19, 2024
16 checks passed