diff --git a/tests/genomics/test_utils.py b/tests/genomics/test_utils.py index 96c73add..5e1cf0f3 100644 --- a/tests/genomics/test_utils.py +++ b/tests/genomics/test_utils.py @@ -69,88 +69,69 @@ def test_generate_mappings_genome_id_bgc_id_empty_dir(tmp_path, caplog): @pytest.fixture -def strain_collection() -> StrainCollection: - """Return a StrainCollection object.""" - sc = StrainCollection() - - strain = Strain("STRAIN_01") - strain.add_alias("BGC_01") - sc.add(strain) - - strain = Strain("STRAIN_02") - strain.add_alias("BGC_02") - strain.add_alias("BGC_02_1") - sc.add(strain) - - strain = Strain("SAMPLE_BGC_03") - sc.add(strain) - - return sc - - -@pytest.fixture -def bgc_list() -> list[BGC]: +def bgcs() -> list[BGC]: """Return a list of BGC objects.""" - return [BGC("BGC_01", "NPR"), BGC("BGC_02", "Alkaloid"), BGC("SAMPLE_BGC_03", "Polyketide")] - - -@pytest.fixture -def gcf_list() -> list[GCF]: - """Return a list of GCF objects.""" - gcf1 = GCF("1") - gcf1.bgc_ids |= {"BGC_01"} - gcf2 = GCF("2") - gcf2.bgc_ids |= {"BGC_02", "SAMPLE_BGC_03"} - return [gcf1, gcf2] - - -@pytest.fixture -def gcf_list_error() -> list[GCF]: - """Return a list of GCF objects for testing errors.""" - gcf1 = GCF("1") - gcf1.bgc_ids |= {"SAMPLE_BGC_03", "BGC_04"} - return [gcf1] + return [BGC("BGC_01", "NPR"), BGC("BGC_02", "Alkaloid"), BGC("BGC_03", "Polyketide")] -def test_add_strain_to_bgc(strain_collection, bgc_list): +def test_add_strain_to_bgc(bgcs): """Test add_strain_to_bgc function.""" - for bgc in bgc_list: - assert bgc.strain is None - add_strain_to_bgc(strain_collection, bgc_list) - for bgc in bgc_list: - assert bgc.strain is not None - assert bgc_list[0].strain.id == "STRAIN_01" - assert bgc_list[1].strain.id == "STRAIN_02" - assert bgc_list[2].strain.id == "SAMPLE_BGC_03" - - -def test_add_strain_to_bgc_error(strain_collection): + strain1 = Strain("STRAIN_01") + strain1.add_alias("BGC_01") + strain2 = Strain("STRAIN_02") + strain2.add_alias("BGC_02") + strain2.add_alias("BGC_02_1") + strain3 = Strain("STRAIN_03") + strains = StrainCollection() + strains.add(strain1) + strains.add(strain2) + strains.add(strain3) + + bgc_with_strain, bgc_without_strain = add_strain_to_bgc(strains, bgcs) + + assert len(bgc_with_strain) == 2 + assert len(bgc_without_strain) == 1 + assert bgc_with_strain == [bgcs[0], bgcs[1]] + assert bgc_without_strain == [bgcs[2]] + assert bgc_with_strain[0].strain == strain1 + assert bgc_with_strain[1].strain == strain2 + assert bgc_without_strain[0].strain is None + + +def test_add_strain_to_bgc_error(bgcs): """Test add_strain_to_bgc function error.""" - bgcs = [BGC("BGC_04", "NPR")] + strain1 = Strain("STRAIN_01") + strain1.add_alias("BGC_01") + strain2 = Strain("STRAIN_02") + strain2.add_alias("BGC_01") + strains = StrainCollection() + strains.add(strain1) + strains.add(strain2) + with pytest.raises(ValueError) as e: - add_strain_to_bgc(strain_collection, bgcs) - assert "Strain id 'BGC_04' from BGC object 'BGC_04' not found" in e.value.args[0] + add_strain_to_bgc(strains, bgcs) + + assert "Multiple strain objects found for BGC id 'BGC_01'" in e.value.args[0] -def test_add_bgc_to_gcf(bgc_list, gcf_list): +def test_add_bgc_to_gcf(bgcs): """Test add_bgc_to_gcf function.""" - assert gcf_list[0].bgc_ids == {"BGC_01"} - assert gcf_list[1].bgc_ids == {"BGC_02", "SAMPLE_BGC_03"} - assert len(gcf_list[0].bgcs) == 0 - assert len(gcf_list[1].bgcs) == 0 - add_bgc_to_gcf(bgc_list, gcf_list) - assert gcf_list[0].bgc_ids == {"BGC_01"} - assert gcf_list[1].bgc_ids == {"BGC_02", "SAMPLE_BGC_03"} - assert len(gcf_list[0].bgcs) == 1 - assert len(gcf_list[1].bgcs) == 2 - assert gcf_list[0].bgcs == set(bgc_list[:1]) - assert gcf_list[1].bgcs == set(bgc_list[1:]) - - -def test_add_bgc_to_gcf_error(bgc_list, gcf_list_error): - """Test add_bgc_to_gcf function error.""" - assert gcf_list_error[0].bgc_ids == {"SAMPLE_BGC_03", "BGC_04"} - assert len(gcf_list_error[0].bgcs) == 0 - with pytest.raises(KeyError) as e: - add_bgc_to_gcf(bgc_list, gcf_list_error) - assert "BGC id 'BGC_04' from GCF object '1' not found" in e.value.args[0] + gcf1 = GCF("1") + gcf1.bgc_ids = {"BGC_01", "BGC_02"} + gcf2 = GCF("2") + gcf2.bgc_ids = {"BGC_03", "BGC_missing_01"} + gcf3 = GCF("3") + gcf3.bgc_ids = {"BGC_missing_02", "BGC_missing_03"} + gcfs = [gcf1, gcf2, gcf3] + + gcf_with_bgc, gcf_without_bgc, gcf_missing_bgc = add_bgc_to_gcf(bgcs, gcfs) + + assert len(gcf_with_bgc) == 2 + assert len(gcf_without_bgc) == 1 + assert len(gcf_missing_bgc) == 2 + assert gcf_with_bgc == [gcf1, gcf2] + assert gcf_without_bgc == [gcf3] + assert gcf_missing_bgc == {gcf2: {"BGC_missing_01"}, gcf3: {"BGC_missing_02", "BGC_missing_03"}} + assert gcf_with_bgc[0].bgcs == {bgcs[0], bgcs[1]} + assert gcf_with_bgc[1].bgcs == {bgcs[2]} + assert gcf_without_bgc[0].bgcs == set()