Skip to content

Commit

Permalink
fix: update tests to remove statistical error
Browse files Browse the repository at this point in the history
  • Loading branch information
jvwilliams23 committed Jan 22, 2023
1 parent a0338dd commit e65e88a
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions tests/test_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ def read_lung_data():
class TestSSM(unittest.TestCase):
def test_morph_model(self):
num_repititions = 10
for _ in range(0, num_repititions):
for test_sample_id in range(0, num_repititions):
landmark_coordinates = read_lung_data()

ssm_obj = pyssam.SSM(landmark_coordinates)
ssm_obj.create_pca_model(ssm_obj.landmarks_columns_scale)
mean_shape_columnvector = ssm_obj.compute_dataset_mean()
mean_shape = mean_shape_columnvector.reshape(-1, 3)

test_sample_id = np.random.randint(0, len(landmark_coordinates))
# test_sample_id = np.random.randint(0, len(landmark_coordinates))
test_shape_columnvec = (
landmark_coordinates[test_sample_id] - landmark_coordinates[test_sample_id].mean(axis=0)
).reshape(-1)
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_morph_model(self):

def test_morph_model_reduced_dimension(self):
num_repititions = 10
for _ in range(0, num_repititions):
for test_sample_id in range(0, num_repititions):
landmark_coordinates = read_lung_data()

ssm_obj = pyssam.SSM(landmark_coordinates)
Expand All @@ -73,7 +73,7 @@ def test_morph_model_reduced_dimension(self):
mean_shape_columnvector = ssm_obj.compute_dataset_mean()
mean_shape = mean_shape_columnvector.reshape(-1, 3)

test_sample_id = np.random.randint(0, len(landmark_coordinates))
# test_sample_id = np.random.randint(0, len(landmark_coordinates))
test_shape_columnvec = (
landmark_coordinates[test_sample_id] - landmark_coordinates[test_sample_id].mean(axis=0)
).reshape(-1)
Expand All @@ -100,13 +100,12 @@ def test_morph_model_reduced_dimension(self):

def test_fit_model_parameters_all_modes(self):
num_repititions = 10
for _ in range(0, num_repititions):
for test_sample_id in range(0, num_repititions):
landmark_coordinates = read_lung_data()

ssm_obj = pyssam.SSM(landmark_coordinates)
ssm_obj.create_pca_model(ssm_obj.landmarks_columns_scale, desired_variance=0.7)

test_sample_id = np.random.randint(0, len(landmark_coordinates))
target_shape = ssm_obj.landmarks_columns_scale[test_sample_id]
model_parameters = ssm_obj.fit_model_parameters(target_shape, ssm_obj.pca_model_components)
model_parameters = np.where(model_parameters < 5, model_parameters, 3)
Expand All @@ -115,17 +114,16 @@ def test_fit_model_parameters_all_modes(self):
dataset_mean = ssm_obj.compute_dataset_mean()
morphed_shape = ssm_obj.morph_model(dataset_mean, ssm_obj.pca_model_components, model_parameters)
error = abs(target_shape - morphed_shape)
assert np.isclose(error.mean(), 0), f"error is non-zero ({error.mean()})"
assert np.isclose(error.mean(), 0), f"error is non-zero ({error.mean()}) sample {test_sample_id}"

def test_fit_model_parameters_reduced_modes(self):
num_repititions = 10
for _ in range(0, num_repititions):
for test_sample_id in range(0, num_repititions):
landmark_coordinates = read_lung_data()

ssm_obj = pyssam.SSM(landmark_coordinates)
ssm_obj.create_pca_model(ssm_obj.landmarks_columns_scale, desired_variance=0.7)

test_sample_id = np.random.randint(0, len(landmark_coordinates))
target_shape = ssm_obj.landmarks_columns_scale[test_sample_id]
model_parameters = ssm_obj.fit_model_parameters(target_shape, ssm_obj.pca_model_components, num_modes=2)
dataset_mean = ssm_obj.compute_dataset_mean()
Expand Down

0 comments on commit e65e88a

Please sign in to comment.