Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ jobs:
benchmark:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.11'

Expand Down
20 changes: 10 additions & 10 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ jobs:
python-version: [3.8, 3.9, '3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand All @@ -29,9 +29,9 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install dependencies
Expand All @@ -44,9 +44,9 @@ jobs:
type-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install dependencies
Expand All @@ -59,9 +59,9 @@ jobs:
coverage:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install dependencies
Expand All @@ -80,9 +80,9 @@ jobs:
format:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install dependencies
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/code-quality.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ jobs:
code-quality:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.11'

Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ jobs:
docs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
with:
ref: docs # Always check out the docs branch

- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.11'

Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ jobs:
contents: write

steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.11'

Expand All @@ -34,7 +34,7 @@ jobs:
needs: release
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
with:
fetch-depth: 0 # Fetch all history for all tags and branches
- name: Generate release notes
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/security.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ jobs:
security:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.11'

Expand Down
15 changes: 8 additions & 7 deletions pymars/_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import logging
from typing import List, Optional, Tuple

import numpy as np

Expand Down Expand Up @@ -39,7 +40,7 @@ def __init__(self, earth_model: Earth):
self.y_train = None
self.n_samples = 0
self.n_features = 0
self.current_basis_functions: list[BasisFunction] = []
self.current_basis_functions: List[BasisFunction] = []
self.current_B_matrix = None
self.current_coefficients = None
self.current_rss = np.inf
Expand All @@ -53,7 +54,7 @@ def __init__(self, earth_model: Earth):

def _calculate_rss_and_coeffs(
self, B_matrix: np.ndarray, y: np.ndarray, *, drop_nan_rows: bool = True
) -> tuple[float, np.ndarray | None, int]:
) -> Tuple[float, Optional[np.ndarray], int]:
if B_matrix is None or B_matrix.shape[1] == 0:
mean_y = np.mean(y)
rss = np.sum((y - mean_y)**2)
Expand Down Expand Up @@ -93,7 +94,7 @@ def _calculate_rss_and_coeffs(
except np.linalg.LinAlgError:
return np.inf, None, num_valid_rows

def _build_basis_matrix(self, X_processed: np.ndarray, basis_functions: list[BasisFunction]) -> np.ndarray:
def _build_basis_matrix(self, X_processed: np.ndarray, basis_functions: List[BasisFunction]) -> np.ndarray:
if not basis_functions:
return np.empty((X_processed.shape[0], 0))

Expand All @@ -108,7 +109,7 @@ def _build_basis_matrix(self, X_processed: np.ndarray, basis_functions: list[Bas
return B_matrix

def run(self, X_fit_processed: np.ndarray, y_fit: np.ndarray,
missing_mask: np.ndarray, X_fit_original: np.ndarray) -> tuple[list[BasisFunction], np.ndarray]:
missing_mask: np.ndarray, X_fit_original: np.ndarray) -> Tuple[List[BasisFunction], np.ndarray]:
self.X_train = X_fit_processed
self.y_train = y_fit
self.missing_mask = missing_mask
Expand Down Expand Up @@ -225,7 +226,7 @@ def run(self, X_fit_processed: np.ndarray, y_fit: np.ndarray,

return self.current_basis_functions, self.current_coefficients

def _calculate_gcv_for_basis_set(self, basis_functions: list[BasisFunction]) -> tuple[float | None, np.ndarray | None]:
def _calculate_gcv_for_basis_set(self, basis_functions: List[BasisFunction]) -> Tuple[Optional[float], Optional[np.ndarray]]:
if not basis_functions:
# This implies an intercept-only model for GCV calculation purposes
rss_intercept_only = np.sum((self.y_train - np.mean(self.y_train))**2)
Expand Down Expand Up @@ -348,8 +349,8 @@ def _get_allowable_knot_values(self, X_col_original_for_var: np.ndarray, parent_
minspan_countdown = max(0, minspan_abs - 1)
return np.array(final_allowable_knots)

def _generate_candidates(self) -> list[tuple[BasisFunction, BasisFunction | None]]:
candidate_additions: list[tuple[BasisFunction, BasisFunction | None]] = []
def _generate_candidates(self) -> List[Tuple[BasisFunction, Optional[BasisFunction]]]:
candidate_additions: List[Tuple[BasisFunction, Optional[BasisFunction]]] = []
for parent_bf in self.current_basis_functions:
if parent_bf.degree() + 1 > self.model.max_degree: continue
parent_involved_vars = parent_bf.get_involved_variables()
Expand Down
13 changes: 7 additions & 6 deletions pymars/_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import logging
from typing import List, Optional, Tuple

import numpy as np

Expand Down Expand Up @@ -35,10 +36,10 @@ def __init__(self, earth_model: Earth):
self.X_fit_original = None

self.best_gcv_so_far = np.inf
self.best_basis_functions_so_far: list[BasisFunction] = []
self.best_basis_functions_so_far: List[BasisFunction] = []
self.best_coeffs_so_far: np.ndarray = None

def _calculate_rss_and_coeffs(self, B_matrix: np.ndarray, y_data: np.ndarray) -> tuple[float, np.ndarray | None, int]:
def _calculate_rss_and_coeffs(self, B_matrix: np.ndarray, y_data: np.ndarray) -> Tuple[float, Optional[np.ndarray], int]:
"""
Calculates RSS, coefficients, and num_valid_rows, considering NaNs in B_matrix.
y_data is assumed finite.
Expand Down Expand Up @@ -78,7 +79,7 @@ def _calculate_rss_and_coeffs(self, B_matrix: np.ndarray, y_data: np.ndarray) ->
)
return np.inf, None, num_valid_rows

def _build_basis_matrix(self, X_data: np.ndarray, basis_functions: list[BasisFunction],
def _build_basis_matrix(self, X_data: np.ndarray, basis_functions: List[BasisFunction],
missing_mask: np.ndarray) -> np.ndarray:
"""
Constructs the basis matrix B from X_data (which is X_processed)
Expand All @@ -97,7 +98,7 @@ def _build_basis_matrix(self, X_data: np.ndarray, basis_functions: list[BasisFun

def _compute_gcv_for_subset(self, X_fit_processed: np.ndarray, y_fit: np.ndarray,
missing_mask: np.ndarray, X_fit_original: np.ndarray,
basis_subset: list[BasisFunction]) -> tuple[float | None, float | None, np.ndarray | None]:
basis_subset: List[BasisFunction]) -> Tuple[Optional[float], Optional[float], Optional[np.ndarray]]:
"""
Computes GCV, RSS, and coefficients for a given subset of basis functions.
Returns (gcv, rss, coeffs).
Expand Down Expand Up @@ -155,8 +156,8 @@ def _compute_gcv_for_subset(self, X_fit_processed: np.ndarray, y_fit: np.ndarray

def run(self, X_fit_processed: np.ndarray, y_fit: np.ndarray,
missing_mask: np.ndarray, X_fit_original: np.ndarray,
initial_basis_functions: list[BasisFunction],
initial_coefficients: np.ndarray) -> tuple[list[BasisFunction], np.ndarray, float]:
initial_basis_functions: List[BasisFunction],
initial_coefficients: np.ndarray) -> Tuple[List[BasisFunction], np.ndarray, float]:

self.X_train = X_fit_processed
self.y_train = y_fit.ravel()
Expand Down
13 changes: 7 additions & 6 deletions pymars/_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import logging
from typing import List

import numpy as np

Expand All @@ -29,13 +30,13 @@ def __init__(self, X, y, earth_model_instance):
# Pruning pass tracking
# These will store the sequence of models considered during pruning.
# Each element corresponds to a model of a certain size.
self.pruning_trace_basis_functions_: list[list[BasisFunction]] = []
self.pruning_trace_coeffs_: list[np.ndarray] = []
self.pruning_trace_gcv_: list[float] = []
self.pruning_trace_rss_: list[float] = []
self.pruning_trace_basis_functions_: List[List[BasisFunction]] = []
self.pruning_trace_coeffs_: List[np.ndarray] = []
self.pruning_trace_gcv_: List[float] = []
self.pruning_trace_rss_: List[float] = []

# Final selected model details (can be set after pruning)
self.final_basis_: list[BasisFunction] = None
self.final_basis_: List[BasisFunction] = None
self.final_coeffs_ = None
self.final_gcv_ = None
self.final_rss_ = None
Expand All @@ -47,7 +48,7 @@ def log_forward_pass_step(self, basis_functions, coefficients, rss):
self.fwd_coeffs_.append(np.copy(coefficients))
self.fwd_rss_.append(rss)

def log_pruning_step(self, basis_functions: list['BasisFunction'],
def log_pruning_step(self, basis_functions: List['BasisFunction'],
coefficients: np.ndarray, gcv: float, rss: float): # Renamed
"""
Log a model state (basis functions, coefficients, GCV, RSS) encountered
Expand Down
3 changes: 2 additions & 1 deletion pymars/earth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
The main Earth class, coordinating the model fitting process.
"""
import logging
from typing import List, Optional

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
Expand Down Expand Up @@ -127,7 +128,7 @@ def __init__(self, max_degree: int = 1, penalty: float = 3.0, max_terms: int = N
allow_linear: bool = True,
allow_missing: bool = False, # New parameter
feature_importance_type: str = None,
categorical_features: list[int] = None
categorical_features: Optional[List[int]] = None
# TODO: Consider other py-earth params
):
super().__init__()
Expand Down