From e994419598ccdd76c4918cdceb6551d9fbf17388 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 15 Jan 2025 17:56:34 +0100 Subject: [PATCH 1/2] dynamic dtype --- src/autora/theorist/darts/regressor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/autora/theorist/darts/regressor.py b/src/autora/theorist/darts/regressor.py index e5c3c01..181b791 100644 --- a/src/autora/theorist/darts/regressor.py +++ b/src/autora/theorist/darts/regressor.py @@ -628,8 +628,9 @@ def predict(self, X: np.ndarray) -> np.ndarray: # ensures the self.model_ parameter is initialized and otherwise throws an error, # so we check that explicitly here and pass the model which can't be None. assert self.model_ is not None - - y_ = self.model_(torch.as_tensor(X_).float()) + + dtype = next(self.model_.parameters()).dtype + y_ = self.model_(torch.as_tensor(X_).to(dtype)) y = y_.detach().numpy() return y From d0a39d67a30e61fd7f6bb55f84820514b2307ee6 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 15 Jan 2025 18:13:31 +0100 Subject: [PATCH 2/2] test commit --- src/autora/theorist/darts/regressor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/autora/theorist/darts/regressor.py b/src/autora/theorist/darts/regressor.py index 181b791..8b15823 100644 --- a/src/autora/theorist/darts/regressor.py +++ b/src/autora/theorist/darts/regressor.py @@ -628,7 +628,7 @@ def predict(self, X: np.ndarray) -> np.ndarray: # ensures the self.model_ parameter is initialized and otherwise throws an error, # so we check that explicitly here and pass the model which can't be None. assert self.model_ is not None - + dtype = next(self.model_.parameters()).dtype y_ = self.model_(torch.as_tensor(X_).to(dtype)) y = y_.detach().numpy()