From 96ff8ba1623af9278bf05da75eeb7f781a01682e Mon Sep 17 00:00:00 2001 From: golmschenk Date: Tue, 6 Aug 2024 13:22:58 -0400 Subject: [PATCH] Correct tests --- tests/unit_tests/test_hydryss_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/test_hydryss_model.py b/tests/unit_tests/test_hydryss_model.py index a293747..1b4f3c8 100644 --- a/tests/unit_tests/test_hydryss_model.py +++ b/tests/unit_tests/test_hydryss_model.py @@ -1,7 +1,7 @@ import torch from qusi.internal.hadryss_model import Hadryss, HadryssBinaryClassEndModule, \ - HadryssMultiClassEndModule + HadryssMultiClassScoreEndModule def test_lengths_give_correct_output_size(): @@ -41,7 +41,7 @@ def test_binary_classification_end_module_produces_expected_shape(): def test_multi_class_classification_end_module_produces_expected_shape(): - model = Hadryss.new(input_length=100, end_module=HadryssMultiClassEndModule.new(number_of_classes=3)) + model = Hadryss.new(input_length=100, end_module=HadryssMultiClassScoreEndModule.new(number_of_classes=3)) output = model(torch.arange(7 * 100, dtype=torch.float32).reshape([7, 100]))