Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
timodonnell committed Mar 14, 2024
1 parent db677b9 commit 9eea537
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
1 change: 0 additions & 1 deletion test/test_released_predictors_on_hpv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,3 @@ def test_on_hpv(predictors, df=DF):
mean_scores = scores_df.mean()
assert_greater(mean_scores["allele-specific"], mean_scores["netmhcpan4"])
assert_greater(mean_scores["pan-allele"], mean_scores["netmhcpan4"])
return scores_df
8 changes: 5 additions & 3 deletions test/test_released_predictors_well_correlated.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def test_correlation(
alleles=None,
num_peptides_per_length=1000,
lengths=[8, 9, 10],
debug=False):
debug=False,
return_result=False):
peptides = []
for length in lengths:
peptides.extend(random_peptides(num_peptides_per_length, length))
Expand Down Expand Up @@ -77,7 +78,8 @@ def test_correlation(
print("Mean correlation", results_df.correlation.mean())
assert_greater(results_df.correlation.mean(), 0.65)

return results_df
if return_result:
return results_df


parser = argparse.ArgumentParser(usage=__doc__)
Expand All @@ -91,7 +93,7 @@ def test_correlation(
# If run directly from python, leave the user in a shell to explore results.
startup()
args = parser.parse_args(sys.argv[1:])
result = test_correlation(alleles=args.alleles, debug=True)
result = test_correlation(alleles=args.alleles, debug=True, return_result=True)

# Leave in ipython
import ipdb # pylint: disable=import-error
Expand Down
4 changes: 2 additions & 2 deletions test/test_train_pan_allele_models_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def setup_module():
'minibatch_size': 256,
'optimizer': 'rmsprop',
'output_activation': 'sigmoid',
'patience': 10,
'patience': 5,
'peptide_allele_merge_activation': '',
'peptide_allele_merge_method': 'concatenate',
'peptide_amino_acid_encoding': 'BLOSUM62',
Expand Down Expand Up @@ -185,7 +185,7 @@ def run_and_check(n_jobs=0, delete=True, additional_args=[]):
predictions = result.predict(peptides=["SLYNTVATL"],
alleles=["HLA-A*02:01"])
assert_equal(predictions.shape, (1,))
assert_array_less(predictions, 1000)
assert_array_less(predictions, 2000)

if delete:
print("Deleting: %s" % models_dir)
Expand Down

0 comments on commit 9eea537

Please sign in to comment.