Skip to content

Commit

Permalink
Fix and rename workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewfullard committed Oct 7, 2024
1 parent f18a3c2 commit 8f24d7f
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 77 deletions.
12 changes: 6 additions & 6 deletions docs/workflows/simple_workflow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from tardis.workflows.simple_simulation import SimpleSimulation\n",
"from tardis.workflows.simple_tardis_workflow import SimpleTARDISWorkflow\n",
"from tardis.io.configuration.config_reader import Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -25,7 +25,7 @@
"metadata": {},
"outputs": [],
"source": [
"workflow = SimpleSimulation(config)"
"workflow = SimpleTARDISWorkflow(config)"
]
},
{
Expand All @@ -39,7 +39,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -102,7 +102,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.4"
}
},
"nbformat": 4,
Expand Down
10 changes: 5 additions & 5 deletions docs/workflows/standard_workflow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
"metadata": {},
"outputs": [],
"source": [
"from tardis.workflows.standard_simulation import StandardSimulation\n",
"from tardis.workflows.standard_tardis_workflow import StandardTARDISWorkflow\n",
"from tardis.io.configuration.config_reader import Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -25,7 +25,7 @@
"metadata": {},
"outputs": [],
"source": [
"workflow = StandardSimulation(config, show_convergence_plots=True,show_progress_bars=True,convergence_plots_kwargs={\"export_convergence_plots\":True})"
"workflow = StandardTARDISWorkflow(config, show_convergence_plots=True,show_progress_bars=True,convergence_plots_kwargs={\"export_convergence_plots\":True})"
]
},
{
Expand All @@ -39,7 +39,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -102,7 +102,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.4"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from astropy import units as u

from tardis import constants as const
from tardis.io.atom_data.base import AtomData
from tardis.io.model.parse_atom_data import parse_atom_data
from tardis.model import SimulationState
from tardis.opacities.macro_atom.macroatom_solver import MacroAtomSolver
from tardis.opacities.opacity_solver import OpacitySolver
from tardis.plasma.assembly.legacy_assembly import assemble_plasma
from tardis.plasma.radiation_field import DilutePlanckianRadiationField
from tardis.plasma.standard_plasmas import assemble_plasma
from tardis.simulation.convergence import ConvergenceSolver
from tardis.spectrum.base import SpectrumSolver
from tardis.spectrum.formal_integral import FormalIntegrator
Expand All @@ -24,22 +26,39 @@
logger = logging.getLogger(__name__)


class SimpleSimulation(WorkflowLogging):
class SimpleTARDISWorkflow(WorkflowLogging):
show_progress_bars = is_notebook()
enable_virtual_packet_logging = False
log_level = None
specific_log_level = None

def __init__(self, configuration):
super().__init__(configuration, self.log_level, self.specific_log_level)
atom_data = self._get_atom_data(configuration)
atom_data = parse_atom_data(configuration)

self.line_interaction_type = configuration.plasma.line_interaction_type

# set up states and solvers
self.simulation_state = SimulationState.from_config(
configuration,
atom_data=atom_data,
)

self.opacity_state = None
self.opacity_solver = OpacitySolver(
self.line_interaction_type,
configuration.plasma.disable_line_scattering,
)

self.macro_atom_state = None
if self.line_interaction_type in (
"downbranch",
"macroatom",
):
self.macro_atom_solver = MacroAtomSolver()
else:
self.macro_atom_solver = None

self.plasma_solver = assemble_plasma(
configuration,
self.simulation_state,
Expand Down Expand Up @@ -117,49 +136,6 @@ def __init__(self, configuration):
self.convergence_strategy.t_inner
)

def _get_atom_data(self, configuration):
"""Process atomic data from the configuration
Parameters
----------
configuration : Configuration
TARDIS configuration object
Returns
-------
AtomData
Atomic data object
Raises
------
ValueError
If atom data is missing from the configuration
"""
if "atom_data" in configuration:
if Path(configuration.atom_data).is_absolute():
atom_data_fname = Path(configuration.atom_data)
else:
atom_data_fname = (
Path(configuration.config_dirname) / configuration.atom_data
)

else:
raise ValueError("No atom_data option found in the configuration.")

logger.info(f"\n\tReading Atomic Data from {atom_data_fname}")

try:
atom_data = AtomData.from_hdf(atom_data_fname)
except TypeError:
logger.exception(
"TypeError might be from the use of an old-format of the atomic database, \n"
"please see https://github.com/tardis-sn/tardis-refdata/tree/master/atom_data"
" for the most recent version.",
)
raise

return atom_data

def get_convergence_estimates(self, transport_state):
"""Compute convergence estimates from the transport state
Expand All @@ -185,12 +161,8 @@ def get_convergence_estimates(self, transport_state):
)
)

estimated_t_radiative = (
estimated_radfield_properties.dilute_blackbody_radiationfield_state.temperature
)
estimated_dilution_factor = (
estimated_radfield_properties.dilute_blackbody_radiationfield_state.dilution_factor
)
estimated_t_radiative = estimated_radfield_properties.dilute_blackbody_radiationfield_state.temperature
estimated_dilution_factor = estimated_radfield_properties.dilute_blackbody_radiationfield_state.dilution_factor

emitted_luminosity = calculate_filtered_luminosity(
transport_state.emitted_packet_nu,
Expand Down Expand Up @@ -369,6 +341,8 @@ def solve_montecarlo(self, no_of_real_packets, no_of_virtual_packets=0):
"""
transport_state = self.transport_solver.initialize_transport_state(
self.simulation_state,
self.opacity_state,
self.macro_atom_state,
self.plasma_solver,
no_of_real_packets,
no_of_virtual_packets=no_of_virtual_packets,
Expand Down Expand Up @@ -406,9 +380,9 @@ def initialize_spectrum_solver(
self.spectrum_solver.transport_state = transport_state

if virtual_packet_energies is not None:
self.spectrum_solver._montecarlo_virtual_luminosity.value[
:
] = virtual_packet_energies
self.spectrum_solver._montecarlo_virtual_luminosity.value[:] = (
virtual_packet_energies
)

if self.integrated_spectrum_settings is not None:
# Set up spectrum solver integrator
Expand All @@ -426,6 +400,17 @@ def run(self):
logger.info(
f"\n\tStarting iteration {(self.completed_iterations + 1):d} of {self.total_iterations:d}"
)

self.opacity_state = self.opacity_solver.solve(self.plasma_solver)

if self.macro_atom_solver is not None:
self.macro_atom_state = self.macro_atom_solver.solve(
self.plasma_solver,
self.plasma_solver.atomic_data,
self.opacity_state.tau_sobolev,
self.plasma_solver.stimulated_emission_factor,
)

transport_state, virtual_packet_energies = self.solve_montecarlo(
self.real_packet_count
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
)
from tardis.util.base import is_notebook
from tardis.visualization import ConvergencePlots
from tardis.workflows.simple_simulation import SimpleSimulation
from tardis.workflows.simple_tardis_workflow import SimpleTARDISWorkflow

# logging support
logger = logging.getLogger(__name__)


class StandardSimulation(
SimpleSimulation, PlasmaStateStorerMixin, HDFWriterMixin
class StandardTARDISWorkflow(
SimpleTARDISWorkflow, PlasmaStateStorerMixin, HDFWriterMixin
):
convergence_plots = None
export_convergence_plots = False
Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(
self.enable_virtual_packet_logging = enable_virtual_packet_logging
self.convergence_plots_kwargs = convergence_plots_kwargs

SimpleSimulation.__init__(self, configuration)
SimpleTARDISWorkflow.__init__(self, configuration)

# set up plasma storage
PlasmaStateStorerMixin.__init__(
Expand Down Expand Up @@ -130,12 +130,8 @@ def get_convergence_estimates(self, transport_state):
)
)

estimated_t_radiative = (
estimated_radfield_properties.dilute_blackbody_radiationfield_state.temperature
)
estimated_dilution_factor = (
estimated_radfield_properties.dilute_blackbody_radiationfield_state.dilution_factor
)
estimated_t_radiative = estimated_radfield_properties.dilute_blackbody_radiationfield_state.temperature
estimated_dilution_factor = estimated_radfield_properties.dilute_blackbody_radiationfield_state.dilution_factor

emitted_luminosity = calculate_filtered_luminosity(
transport_state.emitted_packet_nu,
Expand Down Expand Up @@ -222,6 +218,17 @@ def run(self):
self.plasma_solver.electron_densities,
self.simulation_state.t_inner,
)

self.opacity_state = self.opacity_solver.solve(self.plasma_solver)

if self.macro_atom_solver is not None:
self.macro_atom_state = self.macro_atom_solver.solve(
self.plasma_solver,
self.plasma_solver.atomic_data,
self.opacity_state.tau_sobolev,
self.plasma_solver.stimulated_emission_factor,
)

transport_state, virtual_packet_energies = self.solve_montecarlo(
self.real_packet_count
)
Expand Down

0 comments on commit 8f24d7f

Please sign in to comment.