Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some new simple SFHs #120

Merged
merged 4 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions cogsworth/sfh.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,60 @@ def save(self, file_name, key="sfh"):
file[key].attrs["params"] = yaml.dump(params, default_flow_style=None)


class BurstUniformDisc(StarFormationHistory):
"""An extremely simple star formation history, with all stars formed at ``t_burst`` in a uniform disc with
height ``z_max`` and radius ``R_max`` disc, all with metallicity ``Z``.

Parameters
----------

size : `int`
Number of points to sample from the model
t_burst : :class:`~astropy.units.Quantity` [time]
Lookback time at which all stars are formed
z_max : :class:`~astropy.units.Quantity` [length]
Maximum height of the disc
R_max : :class:`~astropy.units.Quantity` [length]
Maximum radius of the disc
Z : `float`, optional
Metallicity of the disc, by default 0.02
"""
def __init__(self, size, t_burst=12 * u.Gyr, z_max=2 * u.kpc, R_max=15 * u.kpc, Z_all=0.02, **kwargs):
self.t_burst = t_burst
self.z_max = z_max
self.R_max = R_max
self.Z_all = Z_all
super().__init__(size=size, components=["disc"], component_masses=[1], **kwargs)

def draw_lookback_times(self, size=None, component=None):
return np.repeat(self.t_burst.value, size) * self.t_burst.unit

def draw_radii(self, size=None, component=None):
return np.random.uniform(0, self.R_max.value**2, size)**(0.5) * self.R_max.unit

def draw_heights(self, size=None, component=None):
return np.random.uniform(-self.z_max.value, self.z_max.value, size) * self.z_max.unit

def draw_phi(self, size=None):
# if no size is given then use the class value
size = self._size if size is None else size
return np.random.uniform(0, 2 * np.pi, size) * u.rad

def get_metallicity(self):
return np.repeat(self.Z_all, self.size) * u.dimensionless_unscaled


class ConstantUniformDisc(BurstUniformDisc):
"""A simple star formation history, with all stars formed at a constant rate between ``t_burst``
and the present day in a uniform disc with height ``z_max`` and radius ``R_max`` disc, all with
metallicity ``Z``.

Based on :class:`BurstUniformDisc`.
"""
def draw_lookback_times(self, size=None, component=None):
return np.random.uniform(0, self.t_burst.value, size) * self.t_burst.unit


class Wagg2022(StarFormationHistory):
"""A semi-empirical model defined in
`Wagg+2022 <https://ui.adsabs.harvard.edu/abs/2021arXiv211113704W/abstract>`_
Expand Down
25 changes: 25 additions & 0 deletions cogsworth/tests/test_galaxy.py → cogsworth/tests/test_sfh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest
import cogsworth.sfh as sfh
import os
import astropy.units as u


class Test(unittest.TestCase):
Expand Down Expand Up @@ -45,6 +46,30 @@ def test_basic_class(self):
it_broke = True
self.assertTrue(it_broke)

def test_burst_uniform_disc(self):
"""Ensure the burst uniform disc class works"""
g = sfh.BurstUniformDisc(size=10000,
t_burst=5 * u.Gyr,
R_max=20 * u.kpc,
z_max=1 * u.kpc,
Z=0.02)
self.assertTrue(np.all(g.tau == 5 * u.Gyr))
self.assertTrue(np.all(g.z <= 1 * u.kpc))
self.assertTrue(np.all(g.rho <= 20 * u.kpc))
self.assertTrue(np.all(g.Z == 0.02))

def test_constant_uniform_disc(self):
"""Ensure the constant uniform disc class works"""
g = sfh.ConstantUniformDisc(size=10000,
t_burst=5 * u.Gyr,
R_max=20 * u.kpc,
z_max=1 * u.kpc,
Z=0.02)
self.assertTrue(np.all(g.tau <= 5 * u.Gyr))
self.assertTrue(np.all(g.z <= 1 * u.kpc))
self.assertTrue(np.all(g.rho <= 20 * u.kpc))
self.assertTrue(np.all(g.Z == 0.02))

def test_bad_inputs(self):
"""Ensure the classes fail with bad input"""
g = sfh.Wagg2022(size=None, immediately_sample=False)
Expand Down
Loading