Skip to content

Commit e65e88a

Browse files
committed
fix: update tests to remove statistical error
1 parent a0338dd commit e65e88a

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

tests/test_ssm.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ def read_lung_data():
2626
class TestSSM(unittest.TestCase):
2727
def test_morph_model(self):
2828
num_repititions = 10
29-
for _ in range(0, num_repititions):
29+
for test_sample_id in range(0, num_repititions):
3030
landmark_coordinates = read_lung_data()
3131

3232
ssm_obj = pyssam.SSM(landmark_coordinates)
3333
ssm_obj.create_pca_model(ssm_obj.landmarks_columns_scale)
3434
mean_shape_columnvector = ssm_obj.compute_dataset_mean()
3535
mean_shape = mean_shape_columnvector.reshape(-1, 3)
3636

37-
test_sample_id = np.random.randint(0, len(landmark_coordinates))
37+
# test_sample_id = np.random.randint(0, len(landmark_coordinates))
3838
test_shape_columnvec = (
3939
landmark_coordinates[test_sample_id] - landmark_coordinates[test_sample_id].mean(axis=0)
4040
).reshape(-1)
@@ -64,7 +64,7 @@ def test_morph_model(self):
6464

6565
def test_morph_model_reduced_dimension(self):
6666
num_repititions = 10
67-
for _ in range(0, num_repititions):
67+
for test_sample_id in range(0, num_repititions):
6868
landmark_coordinates = read_lung_data()
6969

7070
ssm_obj = pyssam.SSM(landmark_coordinates)
@@ -73,7 +73,7 @@ def test_morph_model_reduced_dimension(self):
7373
mean_shape_columnvector = ssm_obj.compute_dataset_mean()
7474
mean_shape = mean_shape_columnvector.reshape(-1, 3)
7575

76-
test_sample_id = np.random.randint(0, len(landmark_coordinates))
76+
# test_sample_id = np.random.randint(0, len(landmark_coordinates))
7777
test_shape_columnvec = (
7878
landmark_coordinates[test_sample_id] - landmark_coordinates[test_sample_id].mean(axis=0)
7979
).reshape(-1)
@@ -100,13 +100,12 @@ def test_morph_model_reduced_dimension(self):
100100

101101
def test_fit_model_parameters_all_modes(self):
102102
num_repititions = 10
103-
for _ in range(0, num_repititions):
103+
for test_sample_id in range(0, num_repititions):
104104
landmark_coordinates = read_lung_data()
105105

106106
ssm_obj = pyssam.SSM(landmark_coordinates)
107107
ssm_obj.create_pca_model(ssm_obj.landmarks_columns_scale, desired_variance=0.7)
108108

109-
test_sample_id = np.random.randint(0, len(landmark_coordinates))
110109
target_shape = ssm_obj.landmarks_columns_scale[test_sample_id]
111110
model_parameters = ssm_obj.fit_model_parameters(target_shape, ssm_obj.pca_model_components)
112111
model_parameters = np.where(model_parameters < 5, model_parameters, 3)
@@ -115,17 +114,16 @@ def test_fit_model_parameters_all_modes(self):
115114
dataset_mean = ssm_obj.compute_dataset_mean()
116115
morphed_shape = ssm_obj.morph_model(dataset_mean, ssm_obj.pca_model_components, model_parameters)
117116
error = abs(target_shape - morphed_shape)
118-
assert np.isclose(error.mean(), 0), f"error is non-zero ({error.mean()})"
117+
assert np.isclose(error.mean(), 0), f"error is non-zero ({error.mean()}) sample {test_sample_id}"
119118

120119
def test_fit_model_parameters_reduced_modes(self):
121120
num_repititions = 10
122-
for _ in range(0, num_repititions):
121+
for test_sample_id in range(0, num_repititions):
123122
landmark_coordinates = read_lung_data()
124123

125124
ssm_obj = pyssam.SSM(landmark_coordinates)
126125
ssm_obj.create_pca_model(ssm_obj.landmarks_columns_scale, desired_variance=0.7)
127126

128-
test_sample_id = np.random.randint(0, len(landmark_coordinates))
129127
target_shape = ssm_obj.landmarks_columns_scale[test_sample_id]
130128
model_parameters = ssm_obj.fit_model_parameters(target_shape, ssm_obj.pca_model_components, num_modes=2)
131129
dataset_mean = ssm_obj.compute_dataset_mean()

0 commit comments

Comments
 (0)