Skip to content

Commit afafd43

Browse files
authored
fix: index tensor on cpu (#961)
### Summary of Changes The index tensor of an `ImageDataset` sometimes ended up on the CPU, instead of the default device, which led to runtime errors. This PR fixes this.
1 parent 5b32acc commit afafd43

File tree

7 files changed

+29
-38
lines changed

7 files changed

+29
-38
lines changed

docs/tutorials/convolutional_neural_network_for_image_classification.ipynb

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
},
4848
{
4949
"cell_type": "code",
50-
"execution_count": null,
5150
"id": "initial_id",
5251
"metadata": {
5352
"collapsed": true
@@ -57,7 +56,8 @@
5756
"\n",
5857
"images, filepaths = ImageList.from_files(\"data/shapes\", return_filenames=True)"
5958
],
60-
"outputs": []
59+
"outputs": [],
60+
"execution_count": null
6161
},
6262
{
6363
"cell_type": "markdown",
@@ -84,8 +84,8 @@
8484
"collapsed": false
8585
},
8686
"id": "66dcf95a3fa51f23",
87-
"execution_count": null,
88-
"outputs": []
87+
"outputs": [],
88+
"execution_count": null
8989
},
9090
{
9191
"cell_type": "markdown",
@@ -108,8 +108,8 @@
108108
"collapsed": false
109109
},
110110
"id": "32056ddf5396e070",
111-
"execution_count": null,
112-
"outputs": []
111+
"outputs": [],
112+
"execution_count": null
113113
},
114114
{
115115
"cell_type": "markdown",
@@ -149,8 +149,8 @@
149149
"collapsed": false
150150
},
151151
"id": "806a8091249d533a",
152-
"execution_count": null,
153-
"outputs": []
152+
"outputs": [],
153+
"execution_count": null
154154
},
155155
{
156156
"cell_type": "markdown",
@@ -175,8 +175,8 @@
175175
"collapsed": false
176176
},
177177
"id": "af68cc0d32655d32",
178-
"execution_count": null,
179-
"outputs": []
178+
"outputs": [],
179+
"execution_count": null
180180
},
181181
{
182182
"cell_type": "markdown",
@@ -198,15 +198,13 @@
198198
},
199199
{
200200
"cell_type": "code",
201-
"source": [
202-
"cnn_fitted = cnn.fit(dataset, epoch_size=32, batch_size=16)"
203-
],
201+
"source": "cnn_fitted = cnn.fit(dataset, epoch_size=8, batch_size=16)",
204202
"metadata": {
205203
"collapsed": false
206204
},
207205
"id": "381627a94d500675",
208-
"execution_count": null,
209-
"outputs": []
206+
"outputs": [],
207+
"execution_count": null
210208
},
211209
{
212210
"cell_type": "markdown",
@@ -227,8 +225,8 @@
227225
"collapsed": false
228226
},
229227
"id": "62f63dd68362c8b7",
230-
"execution_count": null,
231-
"outputs": []
228+
"outputs": [],
229+
"execution_count": null
232230
},
233231
{
234232
"cell_type": "markdown",
@@ -249,8 +247,8 @@
249247
"collapsed": false
250248
},
251249
"id": "779277d73e30554d",
252-
"execution_count": null,
253-
"outputs": []
250+
"outputs": [],
251+
"execution_count": null
254252
},
255253
{
256254
"cell_type": "markdown",
@@ -271,8 +269,8 @@
271269
"collapsed": false
272270
},
273271
"id": "a5ddbbfba41aa7f",
274-
"execution_count": null,
275-
"outputs": []
272+
"outputs": [],
273+
"execution_count": null
276274
},
277275
{
278276
"cell_type": "markdown",
@@ -293,8 +291,8 @@
293291
"collapsed": false
294292
},
295293
"id": "7081595d7100fb42",
296-
"execution_count": null,
297-
"outputs": []
294+
"outputs": [],
295+
"execution_count": null
298296
}
299297
],
300298
"metadata": {

src/safeds/data/labeled/containers/_image_dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,8 @@ def split(
356356
upper_bound=_ClosedBound(1),
357357
)
358358

359+
_init_default_device()
360+
359361
first_dataset: ImageDataset[Out_co] = copy.copy(self)
360362
second_dataset: ImageDataset[Out_co] = copy.copy(self)
361363

src/safeds/ml/nn/converters/_input_converter_image_to_column.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
from safeds.data.image.containers._single_size_image_list import _SingleSizeImageList
77
from safeds.data.labeled.containers import ImageDataset
88
from safeds.data.labeled.containers._image_dataset import _ColumnAsTensor
9-
from safeds.data.tabular.containers import Column
109

1110
from ._input_converter_image import _InputConversionImage
1211

1312
if TYPE_CHECKING:
1413
from torch import Tensor
1514

1615
from safeds.data.image.containers import ImageList
16+
from safeds.data.tabular.containers import Column
1717
from safeds.data.tabular.transformation import OneHotEncoder
1818

1919

@@ -43,9 +43,9 @@ def _data_conversion_output(
4343
output = torch.zeros(len(input_data), len(one_hot_encoder._get_names_of_added_columns()))
4444
output[torch.arange(len(input_data)), output_data] = 1
4545

46-
im_dataset: ImageDataset[Column] = ImageDataset[Column].__new__(ImageDataset)
46+
im_dataset: ImageDataset[Column] = object.__new__(ImageDataset)
4747
im_dataset._output = _ColumnAsTensor._from_tensor(output, column_name, one_hot_encoder)
48-
im_dataset._shuffle_tensor_indices = torch.LongTensor(list(range(len(input_data))))
48+
im_dataset._shuffle_tensor_indices = torch.arange(len(input_data))
4949
im_dataset._shuffle_after_epoch = False
5050
im_dataset._batch_size = 1
5151
im_dataset._next_batch_index = 0

src/safeds/ml/nn/converters/_input_converter_image_to_table.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
from safeds.data.image.containers._single_size_image_list import _SingleSizeImageList
77
from safeds.data.labeled.containers import ImageDataset
88
from safeds.data.labeled.containers._image_dataset import _TableAsTensor
9-
from safeds.data.tabular.containers import Table
109

1110
from ._input_converter_image import _InputConversionImage
1211

1312
if TYPE_CHECKING:
1413
from torch import Tensor
1514

1615
from safeds.data.image.containers import ImageList
16+
from safeds.data.tabular.containers import Table
1717

1818

1919
class InputConversionImageToTable(_InputConversionImage):
@@ -33,9 +33,9 @@ def _data_conversion_output(self, input_data: ImageList, output_data: Tensor) ->
3333
output = torch.zeros(len(input_data), len(column_names))
3434
output[torch.arange(len(input_data)), output_data] = 1
3535

36-
im_dataset: ImageDataset[Table] = ImageDataset[Table].__new__(ImageDataset)
36+
im_dataset: ImageDataset[Table] = object.__new__(ImageDataset)
3737
im_dataset._output = _TableAsTensor._from_tensor(output, column_names)
38-
im_dataset._shuffle_tensor_indices = torch.LongTensor(list(range(len(input_data))))
38+
im_dataset._shuffle_tensor_indices = torch.arange(len(input_data))
3939
im_dataset._shuffle_after_epoch = False
4040
im_dataset._batch_size = 1
4141
im_dataset._next_batch_index = 0

src/src/resources/to_csv_file.csv

Lines changed: 0 additions & 4 deletions
This file was deleted.

src/src/resources/to_json_file.json

Lines changed: 0 additions & 5 deletions
This file was deleted.
-1.8 KB
Binary file not shown.

0 commit comments

Comments
 (0)