Skip to content

Commit 9e40b65

Browse files
Gerhardsa0megalinter-botlars-reimann
authored
fix: Conversion of tabular dataset to tensors (#757)
### Summary of Changes Fixed conversion of tabular dataset to tensors and associated tests. --------- Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com> Co-authored-by: Lars Reimann <mail@larsreimann.com>
1 parent 92622fb commit 9e40b65

File tree

5 files changed

+14
-15
lines changed

5 files changed

+14
-15
lines changed

src/safeds/data/labeled/containers/_image_dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,10 @@ def __init__(self, table: Table) -> None:
294294
_init_default_device()
295295

296296
self._column_names = table.column_names
297-
self._tensor = torch.Tensor(table._data_frame.to_torch()).to(_get_device())
297+
if table.number_of_rows == 0:
298+
self._tensor = torch.empty((0, table.number_of_columns), dtype=torch.float32).to(_get_device())
299+
else:
300+
self._tensor = table._data_frame.to_torch().to(_get_device())
298301

299302
if not torch.all(self._tensor.sum(dim=1) == torch.ones(self._tensor.size(dim=0))):
300303
raise ValueError(

src/safeds/data/labeled/containers/_time_series_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,8 @@ def _create_dataset(features: torch.Tensor, target: torch.Tensor) -> Dataset:
322322

323323
class _CustomDataset(Dataset):
324324
def __init__(self, features_dataset: torch.Tensor, target_dataset: torch.Tensor):
325-
self.X = features_dataset
326-
self.Y = target_dataset.unsqueeze(-1)
325+
self.X = features_dataset.float()
326+
self.Y = target_dataset.unsqueeze(-1).float()
327327
self.len = self.X.shape[0]
328328

329329
def __getitem__(self, item: int) -> tuple[torch.Tensor, torch.Tensor]:
@@ -341,8 +341,8 @@ def _create_dataset_predict(features: torch.Tensor) -> Dataset:
341341
_init_default_device()
342342

343343
class _CustomDataset(Dataset):
344-
def __init__(self, features: torch.Tensor):
345-
self.X = features
344+
def __init__(self, datas: torch.Tensor):
345+
self.X = datas.float()
346346
self.len = self.X.shape[0]
347347

348348
def __getitem__(self, item: int) -> torch.Tensor:

src/safeds/ml/nn/_output_conversion_time_series.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _data_conversion(self, input_data: TimeSeriesDataset, output_data: Tensor, *
7878
window_size: int = kwargs["window_size"]
7979
forecast_horizon: int = kwargs["forecast_horizon"]
8080
input_data_table = input_data.to_table()
81-
input_data_table = input_data_table.slice_rows(window_size + forecast_horizon)
81+
input_data_table = input_data_table.slice_rows(start=window_size + forecast_horizon)
8282

8383
return input_data_table.add_columns(
8484
[Column(self._prediction_name, output_data.tolist())],

tests/safeds/data/labeled/containers/_time_series_dataset/test_into_dataloader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def test_should_create_dataloader_invalid(
175175
1,
176176
0,
177177
OutOfBoundsError,
178-
r"forecast_horizon \(=0\) is not inside \[1, \u221e\).",
178+
None,
179179
),
180180
(
181181
Table(
@@ -189,7 +189,7 @@ def test_should_create_dataloader_invalid(
189189
0,
190190
1,
191191
OutOfBoundsError,
192-
r"window_size \(=0\) is not inside \[1, \u221e\).",
192+
None,
193193
),
194194
],
195195
ids=[
@@ -204,7 +204,7 @@ def test_should_create_dataloader_predict_invalid(
204204
window_size: int,
205205
forecast_horizon: int,
206206
error_type: type[ValueError],
207-
error_msg: str,
207+
error_msg: str | None,
208208
device: Device,
209209
) -> None:
210210
configure_test_with_device(device)

tests/safeds/ml/nn/test_forward_workflow.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,8 @@ def test_forward_model(device: Device) -> None:
2323
path=resolve_resource_path(_inflation_path),
2424
)
2525
table_1 = table_1.remove_columns(["date"])
26-
table_2 = table_1.slice_rows(length=table_1.number_of_rows - 14)
27-
table_2 = table_2.add_columns(
28-
[
29-
table_1.slice_rows(start=14).get_column("value").rename("target"),
30-
]
31-
)
26+
table_2 = table_1.slice_rows(start=0, length=table_1.number_of_rows - 14)
27+
table_2 = table_2.add_columns([(table_1.slice_rows(start=14)).get_column("value").rename("target")])
3228
train_table, test_table = table_2.split_rows(0.8)
3329

3430
ss = StandardScaler()

0 commit comments

Comments
 (0)