Skip to content

Commit 2235440

Browse files
committed
Move get_device to a separate location
1 parent d8eff4c commit 2235440

File tree

5 files changed

+15
-18
lines changed

5 files changed

+15
-18
lines changed

examples/transit_infer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from qusi.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset
77
from qusi.hadryss_model import Hadryss
8-
from qusi.infer_session import get_device, infer_session
8+
from qusi.infer_session import infer_session
9+
from qusi.device import get_device
910
from qusi.light_curve_collection import LightCurveCollection
1011
from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve
1112

examples/transit_infinite_dataset_test.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torchmetrics.classification import BinaryAccuracy
99

1010
from qusi.hadryss_model import Hadryss
11+
from qusi.device import get_device
1112
from qusi.light_curve_collection import LabeledLightCurveCollection
1213
from qusi.light_curve_dataset import LightCurveDataset
1314
from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve
@@ -71,14 +72,6 @@ def infinite_datasets_test_session(test_datasets: list[LightCurveDataset], model
7172
return results
7273

7374

74-
def get_device():
75-
if torch.cuda.is_available():
76-
device = torch.device('cuda')
77-
else:
78-
device = torch.device('cpu')
79-
return device
80-
81-
8275
def infinite_dataset_test_phase(dataloader, model: Module, metric_functions: list[Module], device: Device, steps: int):
8376
batch_count = 0
8477
metric_totals = torch.zeros(size=[len(metric_functions)])

src/qusi/device.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import torch
2+
from torch.types import Device
3+
4+
5+
def get_device() -> Device:
6+
if torch.cuda.is_available():
7+
device = torch.device("cuda")
8+
else:
9+
device = torch.device("cpu")
10+
return device

src/qusi/infer_session.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,6 @@ def infer_session(
2222
return results
2323

2424

25-
def get_device() -> Device:
26-
if torch.cuda.is_available():
27-
device = torch.device("cuda")
28-
else:
29-
device = torch.device("cpu")
30-
return device
31-
32-
3325
def infer_phase(dataloader, model: Module, device: Device):
3426
batch_count = 0
3527
batches_of_predicted_targets = []

tests/end_to_end_tests/test_toy_infer_session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
import numpy as np
55

6-
from qusi.infer_session import get_device, infer_session
6+
from qusi.infer_session import infer_session
7+
from qusi.device import get_device
78
from qusi.light_curve_dataset import (
89
default_light_curve_post_injection_transform,
910
)

0 commit comments

Comments
 (0)