diff --git a/kliff/loss.py b/kliff/loss.py index 0156fcb9..b384e0f3 100644 --- a/kliff/loss.py +++ b/kliff/loss.py @@ -36,8 +36,8 @@ def energy_forces_residual( identifier: str, natoms: int, weight: float, - prediction: Union[np.array, torch.Tensor], - reference: Union[np.array, torch.Tensor], + prediction: np.array, + reference: np.array, data: Dict[str, Any], ): """ @@ -109,8 +109,8 @@ def energy_residual( identifier: str, natoms: int, weight: float, - prediction: Union[np.array, torch.Tensor], - reference: Union[np.array, torch.Tensor], + prediction: np.array, + reference: np.array, data: Dict[str, Any], ): """ @@ -129,8 +129,8 @@ def forces_residual( identifier: str, natoms: int, weight: float, - prediction: Union[np.array, torch.Tensor], - reference: Union[np.array, torch.Tensor], + prediction: np.array, + reference: np.array, data: Dict[str, Any], ): """