
RBF Regression for Free Energy Surface Fitting #39

Merged: craabreu merged 10 commits into main from rbf-regression on Nov 13, 2025

@craabreu
Owner

Radial Basis Function Regression for Free Energy Surface Fitting

Summary

This PR introduces a new ForceMatchingRegressor class that uses radial basis functions (RBFs) to fit free energy surfaces from position-force pairs. The implementation supports both periodic and non-periodic dynamical variables, trains with PyTorch Lightning, and ships with unit tests that reach 100% coverage of regression.py.

Key Features

1. Per-Dimension Kernel Bandwidths

  • Each RBF kernel now has separate bandwidth parameters $\sigma_{m,k}$ for each dimension $k$
  • Enables anisotropic kernels that can adapt to different length scales per dimension
  • More flexible than a single bandwidth per kernel
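
The effect of per-dimension bandwidths can be illustrated with a minimal sketch in plain Python. The function name `gaussian_kernel` and the numbers are hypothetical and chosen only for illustration; this is not the PR's PyTorch implementation, and it assumes non-periodic variables.

```python
import math

def gaussian_kernel(s, center, sigmas):
    """Unweighted Gaussian kernel with one bandwidth per dimension."""
    exponent = sum(
        (s_k - c_k) ** 2 / sigma_k ** 2
        for s_k, c_k, sigma_k in zip(s, center, sigmas)
    )
    return math.exp(-0.5 * exponent)

center = (0.0, 0.0)
# The same displacement of 1.0, applied along each axis in turn.
along_x = (1.0, 0.0)
along_y = (0.0, 1.0)

isotropic = (1.0, 1.0)    # one length scale shared by both dimensions
anisotropic = (2.0, 0.5)  # wide along x, narrow along y

iso_x = gaussian_kernel(along_x, center, isotropic)
iso_y = gaussian_kernel(along_y, center, isotropic)
aniso_x = gaussian_kernel(along_x, center, anisotropic)
aniso_y = gaussian_kernel(along_y, center, anisotropic)
```

With a shared bandwidth the kernel decays equally in both directions (`iso_x == iso_y`), whereas the anisotropic kernel decays much faster along the narrow dimension (`aniso_y < aniso_x`).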

2. Periodic and Non-Periodic Variable Support

  • Automatically detects periodic vs non-periodic variables from DynamicalVariable bounds
  • Uses appropriate distance functions:
    • Periodic: $\delta^2 = (L/\pi)^2 \sin^2(\pi d/L)$, where $d$ is the displacement from the kernel center and $L$ is the period
    • Non-periodic: $\delta^2 = d^2$ (squared Euclidean distance)
  • Supports mixed periodic/non-periodic variable spaces
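
The two distance functions above can be sketched as follows. The helper name `squared_distance` and the convention of passing `period=None` for a non-periodic variable are assumptions made for this example, not part of the openxps API.

```python
import math

def squared_distance(d, period=None):
    """Squared distance for a displacement d.

    period=None marks a non-periodic variable; otherwise `period` is L.
    """
    if period is None:
        return d * d
    return (period / math.pi) ** 2 * math.sin(math.pi * d / period) ** 2

L = 2 * math.pi  # e.g. a dihedral angle in radians

small = squared_distance(0.01, period=L)        # close to 0.01**2
shifted = squared_distance(0.01 + L, period=L)  # one full period later
half = squared_distance(L / 2, period=L)        # largest value, (L/pi)**2
plain = squared_distance(2.0)                   # non-periodic: d**2
```

For small displacements the periodic form reduces to $d^2$, it is unchanged when $d$ is shifted by a full period, and it reaches its maximum $(L/\pi)^2$ at $d = L/2$.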

3. Dynamical Variable Integration

  • Uses DynamicalVariable objects to extract bounds and periodicity information
  • Automatically converts to MD units for consistency
  • Centers are initialized within the actual bounds of each variable
  • Parameter extraction wraps periodic centers according to bounds
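
Wrapping a periodic center back into its bounds can be sketched in a few lines. The function `wrap` is hypothetical, and the half-open interval `[lower, upper)` is an assumption; the PR description does not state which convention the implementation follows.

```python
import math

def wrap(value, lower, upper):
    """Map value into [lower, upper) by shifting it a whole number of periods."""
    period = upper - lower
    return lower + (value - lower) % period

# Centers that drifted outside [-pi, pi) during optimization.
above = wrap(3.5, -math.pi, math.pi)   # 3.5 - 2*pi
below = wrap(-4.0, -math.pi, math.pi)  # -4.0 + 2*pi
inside = wrap(0.5, -math.pi, math.pi)  # already inside, unchanged up to rounding
```

Because the periodic distance depends on the displacement only modulo the period, wrapping a center changes how it is reported but not the fitted surface.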

4. Training Features

  • Model Checkpointing: Automatically saves and loads the best model (lowest validation loss)
  • Early Stopping: Prevents overfitting with configurable patience
  • Validation Set: Optional validation split for monitoring generalization
  • Temporary Checkpoints: Uses TemporaryDirectory for automatic cleanup
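
The interplay between patience-based early stopping and best-model tracking can be sketched without any training framework. The function below is a simplified stand-in written for this example: it replays a given list of validation losses and ignores options such as a minimum improvement threshold that a real callback may offer.

```python
def train_with_early_stopping(validation_losses, patience):
    """Return (best_epoch, best_loss, epochs_run) for a sequence of losses.

    Training stops once `patience` consecutive epochs fail to improve on the
    best validation loss seen so far.
    """
    best_epoch, best_loss = None, float("inf")
    epochs_without_improvement = 0
    epochs_run = 0
    for epoch, loss in enumerate(validation_losses):
        epochs_run += 1
        if loss < best_loss:
            # A real trainer would save a checkpoint of the model here.
            best_epoch, best_loss = epoch, loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return best_epoch, best_loss, epochs_run

losses = [1.00, 0.60, 0.45, 0.47, 0.46, 0.48, 0.30]
result = train_with_early_stopping(losses, patience=3)
```

Here training stops after six epochs, and the state from epoch 2 (loss 0.45) is the one that would be restored. The later loss of 0.30 is never reached, which is the usual trade-off when choosing a small patience.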

5. Type Safety

  • Complete type hints throughout all methods
  • Types removed from docstrings, since Sphinx generates them from the annotations
  • Proper handling of device placement (CPU/GPU/MPS)

API

import numpy as np
from openmm import unit
from openxps import ForceMatchingRegressor, DynamicalVariable
from openxps.bounds import PeriodicBounds, NoBounds

# Define dynamical variables
# (`mass` must be defined beforehand; its value is not shown in this description)
dvs = [
    DynamicalVariable("phi", unit.radian, mass, PeriodicBounds(-np.pi, np.pi, unit.radian)),
    DynamicalVariable("r", unit.nanometer, mass, NoBounds()),
]

# Create regressor
regressor = ForceMatchingRegressor(
    dynamical_variables=dvs,
    num_kernels=256,
    initial_bandwidth=1.0,
    validation_fraction=0.1,
    patience=10,
)

# Fit model (`positions` and `forces` are the position-force training pairs)
regressor.fit(positions, forces)

# Predict potential at new points (`new_positions`)
potentials = regressor.predict(new_positions)

# Get fitted kernel centers, bandwidths, and weights
centers, sigmas, weights = regressor.get_parameters()

Mathematical Formulation

The potential is approximated as:

$$U(\mathbf{s}) = \sum_{m=1}^M w_m \exp\left(-\frac{1}{2} \sum_{k=1}^d \frac{\delta_k^2(s_k - c_{m,k})}{\sigma_{m,k}^2}\right)$$

where:

  • $w_m$: weights for each kernel $m$
  • $c_{m,k}$: centers for kernel $m$, dimension $k$
  • $\sigma_{m,k}$: bandwidths for kernel $m$, dimension $k$
  • $\delta_k(x)$: distance function (periodic or Euclidean) for dimension $k$
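
As a concrete reading of the formula, the following plain-Python sketch evaluates $U(\mathbf{s})$ for a small mixed periodic/non-periodic example. All names (`rbf_potential`, `squared_distance`, the `periods` tuple with `None` for non-periodic dimensions) and all parameter values are assumptions for illustration; the PR's implementation is a PyTorch module.

```python
import math

def squared_distance(d, period=None):
    """delta^2 for displacement d; period=None means non-periodic."""
    if period is None:
        return d * d
    return (period / math.pi) ** 2 * math.sin(math.pi * d / period) ** 2

def rbf_potential(s, centers, sigmas, weights, periods):
    """U(s) = sum_m w_m exp(-0.5 * sum_k delta_k^2(s_k - c_mk) / sigma_mk^2)."""
    total = 0.0
    for c_m, sigma_m, w_m in zip(centers, sigmas, weights):
        exponent = sum(
            squared_distance(s_k - c_mk, period_k) / sigma_mk ** 2
            for s_k, c_mk, sigma_mk, period_k in zip(s, c_m, sigma_m, periods)
        )
        total += w_m * math.exp(-0.5 * exponent)
    return total

# Two kernels; dimension 0 is an angle, dimension 1 is not periodic.
periods = (2 * math.pi, None)
centers = [(0.0, 1.0), (2.0, 3.0)]
sigmas = [(0.5, 1.0), (1.0, 2.0)]
weights = [1.5, -0.5]

u_at_first_center = rbf_potential((0.0, 1.0), centers, sigmas, weights, periods)
u_shifted = rbf_potential((2 * math.pi, 1.0), centers, sigmas, weights, periods)
```

Shifting the periodic coordinate by one full period leaves the potential unchanged, so `u_shifted` agrees with `u_at_first_center` up to rounding.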

The model is trained by minimizing the mean squared error between predicted and actual forces:

$$L = \frac{1}{N} \sum_{i=1}^N \left\lVert \mathbf{f}(\mathbf{s}_i) - \mathbf{F}_i \right\rVert^2$$

where $\mathbf{f}(\mathbf{s}) = -\nabla_{\mathbf{s}} U(\mathbf{s})$ is the predicted force.
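
Differentiating the potential gives the force component $f_k(\mathbf{s}) = \sum_m w_m \phi_m(\mathbf{s})\, g_k(s_k - c_{m,k}) / \sigma_{m,k}^2$, where $\phi_m$ is the exponential factor of kernel $m$ and $g_k(d) = \tfrac{1}{2}\, \mathrm{d}\delta_k^2/\mathrm{d}d$, that is, $g_k(d) = d$ for a non-periodic variable and $g_k(d) = (L/2\pi)\sin(2\pi d/L)$ for a periodic one. The sketch below implements this by hand and checks it against a finite difference. The PR presumably relies on automatic differentiation instead; all names and numbers here are hypothetical.

```python
import math

def squared_distance(d, period=None):
    """delta^2 for displacement d; period=None means non-periodic."""
    if period is None:
        return d * d
    return (period / math.pi) ** 2 * math.sin(math.pi * d / period) ** 2

def half_derivative(d, period=None):
    """One half of the derivative of delta^2 with respect to d."""
    if period is None:
        return d
    return period / (2 * math.pi) * math.sin(2 * math.pi * d / period)

def potential_and_force(s, centers, sigmas, weights, periods):
    """Return U(s) and f(s) = -grad U(s) for the RBF expansion."""
    potential = 0.0
    force = [0.0] * len(s)
    for c_m, sigma_m, w_m in zip(centers, sigmas, weights):
        disp = [s_k - c_k for s_k, c_k in zip(s, c_m)]
        exponent = sum(
            squared_distance(d, p) / sig ** 2
            for d, p, sig in zip(disp, periods, sigma_m)
        )
        phi = w_m * math.exp(-0.5 * exponent)  # weighted kernel value
        potential += phi
        for k, (d, p, sig) in enumerate(zip(disp, periods, sigma_m)):
            force[k] += phi * half_derivative(d, p) / sig ** 2
    return potential, force

def force_matching_loss(samples, reference_forces, model):
    """Mean over samples of the squared norm of (predicted - reference) force."""
    total = 0.0
    for s, f_ref in zip(samples, reference_forces):
        _, f_pred = potential_and_force(s, *model)
        total += sum((fp - fr) ** 2 for fp, fr in zip(f_pred, f_ref))
    return total / len(samples)

periods = (2 * math.pi, None)
model = ([(0.0, 1.0), (2.0, 3.0)], [(0.5, 1.0), (1.0, 2.0)], [1.5, -0.5], periods)

# Check the analytic force against a central finite difference of U.
s, h = (0.3, 1.7), 1e-6
_, analytic = potential_and_force(s, *model)
numeric = []
for k in range(len(s)):
    plus = [x + h * (i == k) for i, x in enumerate(s)]
    minus = [x - h * (i == k) for i, x in enumerate(s)]
    u_plus = potential_and_force(plus, *model)[0]
    u_minus = potential_and_force(minus, *model)[0]
    numeric.append(-(u_plus - u_minus) / (2 * h))
max_error = max(abs(a - n) for a, n in zip(analytic, numeric))

# The loss vanishes when the reference forces come from the model itself.
self_loss = force_matching_loss([s], [analytic], model)
```

Because the loss only involves gradients of $U$, the fitted potential is determined up to an additive constant, which is the usual situation in force matching.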

Implementation Details

Classes

  1. RBFPotential (nn.Module): Core RBF potential implementation

    • Handles periodic/non-periodic distance computation
    • Computes potential and gradients efficiently
  2. GradMatch (pl.LightningModule): PyTorch Lightning module for training

    • Implements training and validation steps
    • Configures optimizer (AdamW)
  3. ForceMatchingRegressor: High-level API class

    • Handles data preprocessing and splitting
    • Manages training with callbacks
    • Provides fit(), predict(), and get_parameters() methods
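
The data-splitting step can be pictured with the sketch below. It is only one plausible scheme: the PR description does not say whether the split is random, how the validation size is rounded, or how it is seeded, so `split_indices` and its `seed` argument are assumptions.

```python
import random

def split_indices(num_samples, validation_fraction, seed=0):
    """Shuffle sample indices and split them into (train, validation) lists."""
    if not 0.0 <= validation_fraction < 1.0:
        raise ValueError("validation_fraction must be in [0, 1)")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # reproducible shuffle
    num_validation = int(round(validation_fraction * num_samples))
    return indices[num_validation:], indices[:num_validation]

train_idx, val_idx = split_indices(num_samples=1000, validation_fraction=0.1)
```

With `validation_fraction=0.1`, as in the API example, 100 of 1000 position-force pairs would be held out for validation and the remaining 900 used for training.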

Key Implementation Choices

  • Separation of Concerns: Training logic (GradMatch) separated from API (ForceMatchingRegressor)
  • Automatic Device Handling: Input tensors automatically moved to model device
  • Temporary Checkpoints: Uses context manager for automatic cleanup
  • Bounds-Aware Initialization: Centers initialized within actual variable bounds
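
The temporary-checkpoint pattern can be sketched with the standard library alone. In this hypothetical example a small JSON file stands in for the saved model weights; the point is that the best state is reloaded inside the `with` block, before the directory and everything in it is deleted.

```python
import json
import tempfile
from pathlib import Path

def fit_with_temporary_checkpoint(validation_losses):
    """Keep the best 'model state' on disk only for the duration of training."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint = Path(tmpdir) / "best.json"
        best_loss = float("inf")
        for epoch, loss in enumerate(validation_losses):
            if loss < best_loss:
                best_loss = loss
                # Stand-in for saving model weights.
                checkpoint.write_text(json.dumps({"epoch": epoch, "loss": loss}))
        # Reload the best state before the directory is removed.
        best_state = json.loads(checkpoint.read_text())
    return best_state, checkpoint

best_state, stale_path = fit_with_temporary_checkpoint([0.9, 0.4, 0.6, 0.5])
```

After the function returns, `best_state` holds the epoch with the lowest loss and `stale_path` no longer exists on disk, so no checkpoint files are left behind.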

Testing

Comprehensive unit tests (test_regression.py) with 24 test cases covering:

  • RBF potential initialization and computation
  • Periodic and non-periodic distance functions
  • Mixed periodic/non-periodic variables
  • Training and validation steps
  • Model checkpointing and early stopping
  • Prediction and parameter extraction
  • Error handling

Coverage: 100% for regression.py (113 statements)

Breaking Changes

None - this is a new feature addition.

Dependencies

  • pytorch (PyTorch, via conda)
  • lightning (PyTorch Lightning, via conda)
  • numpy

Files Changed

  • openxps/regression.py (new file, 331 lines)
  • openxps/tests/test_regression.py (new file, 384 lines)
  • openxps/__init__.py (updated to export ForceMatchingRegressor)
  • openxps/bounds/base.py (added length attribute to Bounds class)
  • pyproject.toml (added torch and lightning to dependencies)
  • devtools/conda-envs/test_env.yaml (added pytorch and lightning dependencies)
  • devtools/conda-envs/deployment_env.yaml (added pytorch and lightning dependencies)
  • devtools/conda-recipes/anaconda/meta.yaml (added pytorch and lightning to runtime dependencies)

Future Enhancements

Potential future improvements:

  • Support for different distance metrics
  • Custom kernel functions
  • Hyperparameter optimization
  • Model persistence/loading

@craabreu craabreu merged commit 5d52284 into main Nov 13, 2025
9 of 10 checks passed
@craabreu craabreu deleted the rbf-regression branch November 13, 2025 02:32