Skip to content

Commit e72339e

Browse files
fix: fixed devices with new polars implementation (#756)
### Summary of Changes fix: fixed devices with new polars implementation refactor: using `_utils._get_random_seed` in `ImageList.from_files` to shuffle all files --------- Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
1 parent a7d92ae commit e72339e

13 files changed

+19
-16
lines changed

src/safeds/data/image/containers/_image_list.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import TYPE_CHECKING, Literal, overload
1111

1212
from safeds._config import _init_default_device
13+
from safeds._utils import _get_random_seed
1314
from safeds.data.image.containers._image import Image
1415
from safeds.exceptions import OutOfBoundsError, ClosedBound
1516

@@ -176,6 +177,8 @@ def from_files(
176177

177178
_init_default_device()
178179

180+
random.seed(_get_random_seed())
181+
179182
from safeds.data.image.containers._empty_image_list import _EmptyImageList
180183
from safeds.data.image.containers._multi_size_image_list import _MultiSizeImageList
181184
from safeds.data.image.containers._single_size_image_list import _SingleSizeImageList

src/safeds/data/labeled/containers/_image_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import warnings
66
from typing import TYPE_CHECKING, Generic, TypeVar
77

8-
from safeds._config import _init_default_device
8+
from safeds._config import _get_device, _init_default_device
99
from safeds._utils import _structural_hash
1010
from safeds.data.image.containers import ImageList
1111
from safeds.data.image.containers._empty_image_list import _EmptyImageList
@@ -294,7 +294,7 @@ def __init__(self, table: Table) -> None:
294294
_init_default_device()
295295

296296
self._column_names = table.column_names
297-
self._tensor = torch.Tensor(table._data_frame.to_numpy()).to(torch.get_default_device())
297+
self._tensor = torch.Tensor(table._data_frame.to_torch()).to(_get_device())
298298

299299
if not torch.all(self._tensor.sum(dim=1) == torch.ones(self._tensor.size(dim=0))):
300300
raise ValueError(
@@ -355,8 +355,8 @@ def __init__(self, column: Column) -> None:
355355
category=UserWarning,
356356
)
357357
self._one_hot_encoder = OneHotEncoder().fit(column_as_table, [self._column_name])
358-
self._tensor = torch.Tensor(self._one_hot_encoder.transform(column_as_table)._data_frame.to_numpy()).to(
359-
torch.get_default_device(),
358+
self._tensor = torch.Tensor(self._one_hot_encoder.transform(column_as_table)._data_frame.to_torch()).to(
359+
_get_device(),
360360
)
361361

362362
def __eq__(self, other: object) -> bool:

src/safeds/data/labeled/containers/_time_series_dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44
from typing import TYPE_CHECKING, Any
55

6-
from safeds._config import _init_default_device
6+
from safeds._config import _get_device, _init_default_device
77
from safeds._utils import _structural_hash
88
from safeds.exceptions import ClosedBound, OutOfBoundsError
99

@@ -235,7 +235,7 @@ def _into_dataloader_with_window(self, window_size: int, forecast_horizon: int,
235235
label = target_tensor[i + window_size + forecast_horizon]
236236
for col in feature_cols:
237237
data = torch.tensor(col._series.to_numpy(), dtype=torch.float32)
238-
window = torch.cat((window, data[i: i + window_size]), dim=0)
238+
window = torch.cat((window, data[i : i + window_size]), dim=0)
239239
x_s.append(window)
240240
y_s.append(label)
241241
x_s_tensor = torch.stack(x_s)
@@ -279,7 +279,7 @@ def _into_dataloader_with_window_predict(
279279

280280
_init_default_device()
281281

282-
target_tensor = self.target._series.to_torch()
282+
target_tensor = self.target._series.to_torch().to(_get_device())
283283
x_s = []
284284

285285
size = target_tensor.size(0)

src/safeds/data/tabular/containers/_table.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1994,7 +1994,7 @@ def _into_dataloader(self, batch_size: int) -> DataLoader:
19941994
_init_default_device()
19951995

19961996
return DataLoader(
1997-
dataset=_create_dataset(self._data_frame.to_torch(dtype=pl.Float32)),
1997+
dataset=_create_dataset(self._data_frame.to_torch(dtype=pl.Float32).to(_get_device())),
19981998
batch_size=batch_size,
19991999
generator=torch.Generator(device=_get_device()),
20002000
)

tests/safeds/ml/nn/test_cnn_workflow.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,22 +38,22 @@ class TestImageToTableClassifier:
3838
(
3939
1234,
4040
device_cuda,
41-
["grayscale"] * 7,
41+
["white_square"] * 7,
4242
),
4343
(
4444
4711,
4545
device_cuda,
46-
["white_square"] * 7,
46+
["rgba"] * 7,
4747
),
4848
(
4949
1234,
5050
device_cpu,
51-
["grayscale"] * 7,
51+
["white_square"] * 7,
5252
),
5353
(
5454
4711,
5555
device_cpu,
56-
["white_square"] * 7,
56+
["rgba"] * 7,
5757
),
5858
],
5959
ids=["seed-1234-cuda", "seed-4711-cuda", "seed-1234-cpu", "seed-4711-cpu"],
@@ -106,22 +106,22 @@ class TestImageToColumnClassifier:
106106
(
107107
1234,
108108
device_cuda,
109-
["grayscale"] * 7,
109+
["white_square"] * 7,
110110
),
111111
(
112112
4711,
113113
device_cuda,
114-
["white_square"] * 7,
114+
["rgba"] * 7,
115115
),
116116
(
117117
1234,
118118
device_cpu,
119-
["grayscale"] * 7,
119+
["white_square"] * 7,
120120
),
121121
(
122122
4711,
123123
device_cpu,
124-
["white_square"] * 7,
124+
["rgba"] * 7,
125125
),
126126
],
127127
ids=["seed-1234-cuda", "seed-4711-cuda", "seed-1234-cpu", "seed-4711-cpu"],

0 commit comments

Comments
 (0)