Skip to content

Commit

Permalink
Correct tests
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Aug 6, 2024
1 parent fb3aea6 commit 96ff8ba
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/unit_tests/test_hydryss_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from qusi.internal.hadryss_model import Hadryss, HadryssBinaryClassEndModule, \
HadryssMultiClassEndModule
HadryssMultiClassScoreEndModule


def test_lengths_give_correct_output_size():
Expand Down Expand Up @@ -41,7 +41,7 @@ def test_binary_classification_end_module_produces_expected_shape():


def test_multi_class_classification_end_module_produces_expected_shape():
model = Hadryss.new(input_length=100, end_module=HadryssMultiClassEndModule.new(number_of_classes=3))
model = Hadryss.new(input_length=100, end_module=HadryssMultiClassScoreEndModule.new(number_of_classes=3))

output = model(torch.arange(7 * 100, dtype=torch.float32).reshape([7, 100]))

Expand Down

0 comments on commit 96ff8ba

Please sign in to comment.