Skip to content

Commit

Permalink
Simulation solver workflow (#2730)
Browse files Browse the repository at this point in the history
* Basic skeleton

* Lots of progress

Solvers pass info to each other
Simplified some parts

* Adds setup method, refactors some functions, removes unused items

* Move convergence setup to setup method

* Add docs notebook, fix a couple of errors

* Move setup to init

* Fix iteration and convergence

* Simplify spectrumsolver init

* Separate plasma "solver" and simulation state updates

* Simplify convergence solver setup with a dict.

* Minor formatting change

* Minor refactoring

* Fixes plasma update step

* Fixes loggers and progress bars

Also adds docstrings to methods. Updates some methods to use new functionality of the plasma. Adds requirements for the convergence plots (still broken)

* Fixes convergence plot rendering

* Fixes convergence plots in the final iteration

* black

* Simplify convergence plot updating

* Move logging handling to a separate class

* Add HDF output capability to solver

* Move more basic logging back into workflow

* Add not-converged error message

* Update notebook with export option

* Added simple base workflow and changed some verbiage

* Fix typo

* Black format

* Fixes spectrum solver test

* Some suggested refactoring

* Fix and rename workflows

* Apply black

* Match Jack's opacity state setup for easier fixing later

* Ruff formatting
  • Loading branch information
andrewfullard authored Oct 28, 2024
1 parent 494d625 commit 5f11626
Show file tree
Hide file tree
Showing 8 changed files with 1,081 additions and 5 deletions.
110 changes: 110 additions & 0 deletions docs/workflows/simple_workflow.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from tardis.io.configuration.config_reader import Configuration\n",
"from tardis.workflows.simple_tardis_workflow import SimpleTARDISWorkflow\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"config = Configuration.from_yaml('../tardis_example.yml')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"workflow = SimpleTARDISWorkflow(config)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"workflow.run()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"spectrum = workflow.spectrum_solver.spectrum_real_packets\n",
"spectrum_virtual = workflow.spectrum_solver.spectrum_virtual_packets\n",
"spectrum_integrated = workflow.spectrum_solver.spectrum_integrated"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"plt.figure(figsize=(10, 6.5))\n",
"\n",
"spectrum.plot(label=\"Normal packets\")\n",
"spectrum_virtual.plot(label=\"Virtual packets\")\n",
"spectrum_integrated.plot(label='Formal integral')\n",
"\n",
"plt.xlim(500, 9000)\n",
"plt.title(\"TARDIS example model spectrum\")\n",
"plt.xlabel(\"Wavelength [$\\AA$]\")\n",
"plt.ylabel(\"Luminosity density [erg/s/$\\AA$]\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tardis",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
110 changes: 110 additions & 0 deletions docs/workflows/standard_workflow.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tardis.io.configuration.config_reader import Configuration\n",
"from tardis.workflows.standard_tardis_workflow import StandardTARDISWorkflow\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"config = Configuration.from_yaml('../tardis_example.yml')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"workflow = StandardTARDISWorkflow(config, show_convergence_plots=True,show_progress_bars=True,convergence_plots_kwargs={\"export_convergence_plots\":True})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"workflow.run()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"spectrum = workflow.spectrum_solver.spectrum_real_packets\n",
"spectrum_virtual = workflow.spectrum_solver.spectrum_virtual_packets\n",
"spectrum_integrated = workflow.spectrum_solver.spectrum_integrated"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"plt.figure(figsize=(10, 6.5))\n",
"\n",
"spectrum.plot(label=\"Normal packets\")\n",
"spectrum_virtual.plot(label=\"Virtual packets\")\n",
"spectrum_integrated.plot(label='Formal integral')\n",
"\n",
"plt.xlim(500, 9000)\n",
"plt.title(\"TARDIS example model spectrum\")\n",
"plt.xlabel(\"Wavelength [$\\AA$]\")\n",
"plt.ylabel(\"Luminosity density [erg/s/$\\AA$]\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tardis",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion tardis/spectrum/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class SpectrumSolver(HDFWriterMixin):
hdf_name = "spectrum"

def __init__(
self, transport_state, spectrum_frequency_grid, integrator_settings=None
self, transport_state, spectrum_frequency_grid, integrator_settings
):
self.transport_state = transport_state
self.spectrum_frequency_grid = spectrum_frequency_grid
Expand Down
8 changes: 4 additions & 4 deletions tardis/spectrum/tests/test_spectrum_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_initialization(self, simulation):
transport_state = simulation.transport.transport_state
spectrum_frequency_grid = simulation.transport.spectrum_frequency_grid

solver = SpectrumSolver(transport_state, spectrum_frequency_grid)
solver = SpectrumSolver(transport_state, spectrum_frequency_grid, None)
assert solver.transport_state == transport_state
assert np.array_equal(
solver.spectrum_frequency_grid.value, spectrum_frequency_grid.value
Expand All @@ -61,7 +61,7 @@ def test_spectrum_real_packets(self, simulation):
transport_state = simulation.transport.transport_state
spectrum_frequency_grid = simulation.transport.spectrum_frequency_grid

solver = SpectrumSolver(transport_state, spectrum_frequency_grid)
solver = SpectrumSolver(transport_state, spectrum_frequency_grid, None)
result = solver.spectrum_real_packets.luminosity
key = "simulation/spectrum_solver/spectrum_real_packets/luminosity"
expected = self.get_expected_data(key)
Expand All @@ -77,7 +77,7 @@ def test_spectrum_real_packets_reabsorbed(self, simulation):
transport_state = simulation.transport.transport_state
spectrum_frequency_grid = simulation.transport.spectrum_frequency_grid

solver = SpectrumSolver(transport_state, spectrum_frequency_grid)
solver = SpectrumSolver(transport_state, spectrum_frequency_grid, None)
result = solver.spectrum_real_packets_reabsorbed.luminosity
key = "simulation/spectrum_solver/spectrum_real_packets_reabsorbed/luminosity"
expected = self.get_expected_data(key)
Expand All @@ -93,7 +93,7 @@ def test_solve(self, simulation):
transport_state = simulation.transport.transport_state
spectrum_frequency_grid = simulation.transport.spectrum_frequency_grid

solver = SpectrumSolver(transport_state, spectrum_frequency_grid)
solver = SpectrumSolver(transport_state, spectrum_frequency_grid, None)
result_real, result_virtual, result_integrated = solver.solve(
transport_state
)
Expand Down
Empty file added tardis/workflows/__init__.py
Empty file.
Loading

0 comments on commit 5f11626

Please sign in to comment.