diff --git a/tsts/apis.py b/tsts/apis.py index 54f2557..b640598 100644 --- a/tsts/apis.py +++ b/tsts/apis.py @@ -35,11 +35,11 @@ def load_sample( df = df.fillna(0.0) # Take only the values of input & output variables if len(in_feats) > 0: - df = df[in_feats] + df_in = df[in_feats] if len(out_feats) > 0: - df = df[out_feats] - X = torch.tensor(df.values, dtype=torch.float32) - y = torch.tensor(df.values, dtype=torch.float32) + df_out = df[out_feats] + X = torch.tensor(df_in.values, dtype=torch.float32) + y = torch.tensor(df_out.values, dtype=torch.float32) return (X, y)