Skip to content

Commit 2f57f56

Browse files
committed
First draft trainer framework
1 parent 86ab5d2 commit 2f57f56

File tree

11 files changed

+1163
-181
lines changed

11 files changed

+1163
-181
lines changed

kliff/_exceptions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""
2+
This module contains exceptions to be raised in kliff modules, along with details on
3+
where they are raised.
4+
"""
5+
6+
7+
class TrainerError(Exception):
8+
"""
9+
Exceptions to be raised in Trainer and associated classes.
10+
"""
11+
12+
def __init__(self, message):
13+
super().__init__(message)

kliff/dataset/dataset.py

Lines changed: 203 additions & 26 deletions
Large diffs are not rendered by default.

kliff/trainer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
from .kim_trainer import KIMTrainer
12
from .kliff_trainer import Trainer

kliff/trainer/kim_residuals.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import Any, Dict
2+
3+
import numpy as np
4+
5+
6+
def MSE_residuals(
7+
predictions: np.ndarray,
8+
targets: np.ndarray,
9+
) -> np.ndarray:
10+
r"""
11+
Compute the mean squared error (MSE) of the residuals.
12+
13+
Args:
14+
15+
Returns:
16+
The MSE of the residuals.
17+
"""
18+
residuals = predictions - targets
19+
return np.mean(residuals**2)

0 commit comments

Comments
 (0)