def add_contaminant_clusters_to_items_with_subset_clustering(config, consignment):
    """Add contaminants to items within one contiguous subset of a consignment

    A contiguous window (subset) of items is selected at a random position in
    the consignment and the contaminated items are drawn from that window
    without replacement.

    Clustering equal to 0 means all items in the consignment can be contaminated
    with equal probability, i.e., the subset spreads over the whole consignment.
    Clustering equal to 1 means that all items in the subset are contaminated.
    The size of the subset is then directly determined by the number of
    contaminated items.

    :param config: Contamination configuration; reads
        ``config["clustered"]["clustering"]`` and ``config["contamination_rate"]``.
    :param consignment: Consignment whose ``items`` are set to 1 at the selected
        indexes (modified in place); ``num_items`` is read for the item count.
    """
    clustering = config["clustered"]["clustering"]
    num_items = consignment.num_items
    num_of_contaminated_items = num_items_to_contaminate(
        config["contamination_rate"], num_items
    )
    if num_of_contaminated_items == 0:
        return
    # The subset shrinks as clustering grows, but it must always be large
    # enough to hold every contaminated item.
    subset_size = max(round(num_items * (1 - clustering)), num_of_contaminated_items)
    if subset_size == num_items:
        # The subset is the whole consignment, so there is only one position.
        start_index = 0
    else:
        # The upper bound of randint is exclusive. Valid start positions are
        # 0..(num_items - subset_size) inclusive, hence the + 1; without it
        # the subset could never reach the last item of the consignment.
        start_index = np.random.randint(0, num_items - subset_size + 1)
    indexes = np.random.choice(
        np.arange(start_index, start_index + subset_size),
        num_of_contaminated_items,
        replace=False,
    )
    consignment.items[indexes] = 1
    # Sanity check; presumably relies on the consignment having no contaminated
    # items before this call - verify against callers.
    assert np.count_nonzero(consignment.items) == num_of_contaminated_items
@pytest.mark.parametrize("contamination_rate", [0.0, 0.1, 0.2, 0.5, 0.8, 1.0])
@pytest.mark.parametrize("clustering", [0.0, 0.2, 0.5, 0.8, 1.0])
def test_add_contaminant_clusters_to_items_with_subset_clustering(
    contamination_rate, clustering
):
    """Test contaminant clusters with subset clustering

    Checks both the number of contaminated items and that they all fall
    within one contiguous window no larger than the expected subset size.
    """
    config_yaml = f"""
    contamination:
      contamination_rate:
        distribution: fixed_value
        value: {contamination_rate}
      contamination_unit: item
      arrangement: clustered
      clustered:
        distribution: subset
        clustering: {clustering}
    """
    config = load_configuration_yaml_from_text(config_yaml)["contamination"]
    num_items = 100
    consignment = get_consignment(num_items)
    add_contaminant_clusters(config, consignment)
    contaminated_items = int(num_items * contamination_rate)
    assert np.count_nonzero(consignment.items) == contaminated_items

    # The clustering parameter must actually constrain where contaminants land:
    # the span from the first to the last contaminated item cannot exceed the
    # subset size (which is never smaller than the number of contaminants).
    contaminated_indexes = np.flatnonzero(consignment.items)
    if contaminated_indexes.size:
        max_subset_size = max(round(num_items * (1 - clustering)), contaminated_items)
        span = contaminated_indexes[-1] - contaminated_indexes[0] + 1
        assert span <= max_subset_size