Skip to content

Commit

Permalink
Merge branch 'main' into 1076-grm-skipna
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored May 2, 2023
2 parents 56dec88 + e28c0f7 commit 01eb580
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 2 deletions.
5 changes: 4 additions & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ New Features
- Add :func:`sgkit.io.vcf.read_vcf` convenience function.
(:user:`tomwhite`, :pr:`1052`, :issue:`1004`)

- Add :func:`sgkit.hybrid_relationship`, :func:`sgkit.hybrid_inverse_relationship`
- Add :func:`sgkit.hybrid_relationship`, :func:`sgkit.hybrid_inverse_relationship`
and :func:`invert_relationship_matrix` methods.
(:user:`timothymillar`, :pr:`1053`, :issue:`993`)

Expand All @@ -45,6 +45,9 @@ New Features
- Add ``skipna`` option to :func:`genomic_relationship` function.
(:user:`timothymillar`, :pr:`1078`, :issue:`1076`)

- Add ``additional_variant_fields`` option to :func:`sgkit.simulate_genotype_call_dataset` function.
(:user:`benjeffery`, :pr:`1056`)

Bug fixes
~~~~~~~~~

Expand Down
22 changes: 21 additions & 1 deletion sgkit/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def simulate_genotype_call_dataset(
seed: Optional[int] = 0,
missing_pct: Optional[float] = None,
phased: Optional[bool] = None,
additional_variant_fields: Optional[dict] = None,
) -> Dataset:
"""Simulate genotype calls and variant/sample data.
Expand Down Expand Up @@ -50,6 +51,9 @@ def simulate_genotype_call_dataset(
The percentage of missing calls, must be within [0.0, 1.0], optional
phased
Whether genotypes are phased, default is unphased, optional
additional_variant_fields
Additional variant fields to add to the dataset as a dictionary of
{field_name: field_dtype}, optional
Returns
-------
Expand All @@ -62,6 +66,7 @@ def simulate_genotype_call_dataset(
- :data:`sgkit.variables.call_genotype_spec` (variants, samples, ploidy)
- :data:`sgkit.variables.call_genotype_mask_spec` (variants, samples, ploidy)
- :data:`sgkit.variables.call_genotype_phased_spec` (variants, samples), if ``phased`` is not None
- Those specified in ``additional_variant_fields``, if provided
"""
if missing_pct and (missing_pct < 0.0 or missing_pct > 1.0):
raise ValueError("missing_pct must be within [0.0, 1.0]")
Expand All @@ -87,7 +92,7 @@ def simulate_genotype_call_dataset(
["A", "C", "G", "T"], size=(n_variant, n_allele)
).astype("S")
sample_id = np.array([f"S{i}" for i in range(n_sample)])
return create_genotype_call_dataset(
ds = create_genotype_call_dataset(
variant_contig_names=contig_names,
variant_contig=contig,
variant_position=position,
Expand All @@ -96,3 +101,18 @@ def simulate_genotype_call_dataset(
call_genotype=call_genotype,
call_genotype_phased=call_genotype_phased,
)
# Add each additional variant field, if provided, populated with generated data
if additional_variant_fields is not None:
for field_name, field_dtype in additional_variant_fields.items():
if field_dtype in (np.float32, np.float64):
field = rs.rand(n_variant).astype(field_dtype)
elif field_dtype in (np.int8, np.int16, np.int32, np.int64):
field = rs.randint(0, 100, n_variant, dtype=field_dtype)
elif field_dtype is np.bool:
field = rs.rand(n_variant) > 0.5
elif field_dtype is np.str:
field = np.arange(n_variant).astype("S")
else:
raise ValueError(f"Unrecognized dtype {field_dtype}")
ds[field_name] = (("variants",), field)
return ds
32 changes: 32 additions & 0 deletions sgkit/tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,35 @@ def test_simulate_genotype_call_dataset__phased(tmp_path):
ds = simulate_genotype_call_dataset(n_variant=10, n_sample=10, phased=False)
assert "call_genotype_phased" in ds
assert not np.any(ds["call_genotype_phased"])


def test_simulate_genotype_call_dataset__additional_variant_fields():
    """Check that ``additional_variant_fields`` adds one variable per field.

    One field of each supported kind (string, boolean, integer, float) is
    requested and the resulting dataset is checked for the presence of each
    variable, plus its dtype (or, for the string field, its exact contents).
    An unsupported dtype specification must raise ``ValueError``.
    """
    # Use the builtin ``str``/``bool`` instead of ``np.str``/``np.bool``:
    # those NumPy aliases were deprecated in 1.20 and removed in 1.24, and
    # while they existed they were simply the builtin types themselves.
    ds = simulate_genotype_call_dataset(
        n_variant=10,
        n_sample=10,
        phased=True,
        additional_variant_fields={
            "variant_id": str,
            "variant_filter": bool,
            "variant_quality": np.int8,
            "variant_yummyness": np.float32,
        },
    )
    assert "variant_id" in ds
    # String fields are expected to hold the variant index as a byte string.
    assert np.all(ds["variant_id"] == np.arange(10).astype("S"))
    assert "variant_filter" in ds
    assert ds["variant_filter"].dtype == bool
    assert "variant_quality" in ds
    assert ds["variant_quality"].dtype == np.int8
    assert "variant_yummyness" in ds
    assert ds["variant_yummyness"].dtype == np.float32

    # ``None`` is not a recognised dtype specification and must be rejected.
    with pytest.raises(ValueError, match="Unrecognized dtype"):
        simulate_genotype_call_dataset(
            n_variant=10,
            n_sample=10,
            phased=True,
            additional_variant_fields={
                "variant_id": None,
            },
        )

0 comments on commit 01eb580

Please sign in to comment.