Skip to content

Commit 8bfb672

Browse files
committed
fix empty y label torch error
1 parent 13fbe3f commit 8bfb672

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

src/asleep/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def predict(self, X, groups=None):
193193
model.to(self.device)
194194

195195
_, y_pred, _ = sslmodel.predict(
196-
model, dataloader, self.device, output_logits=False)
196+
model, dataloader, self.device, output_logits=False, name='prediction')
197197

198198
y_pred = self.hmms.predict(y_pred, groups=groups)
199199

src/asleep/sslmodel.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def get_sslnet(tag='v1.0.0', pretrained=False):
115115
return sslnet
116116

117117

118-
def predict(model, data_loader, device, output_logits=False):
118+
def predict(model, data_loader, device,
119+
output_logits=False, name='train'):
119120
"""
120121
Iterate over the dataloader and do prediction with a pytorch model.
121122
:param nn.Module model: pytorch Module
@@ -145,8 +146,12 @@ def predict(model, data_loader, device, output_logits=False):
145146
pred_y = torch.argmax(logits, dim=1)
146147
predictions_list.append(pred_y.cpu())
147148
pid_list.extend(pid)
148-
true_list = torch.cat(true_list)
149+
149150
predictions_list = torch.cat(predictions_list)
151+
if name == 'prediction':
152+
true_list = predictions_list
153+
else:
154+
true_list = torch.Tensor([1, 2, 3])
150155

151156
if output_logits:
152157
return (

0 commit comments

Comments
 (0)