From 66f234bc5b99ca79d4ba517b4ea0e9f5d721bbaf Mon Sep 17 00:00:00 2001 From: ndaelman <ndaelman@physik.hu-berlin.de> Date: Tue, 28 Jan 2025 16:13:39 +0100 Subject: [PATCH] Add plotly visualization --- .../schema_packages/properties/electronic.py | 44 ++++++++++++++++--- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/src/nomad_simulations/schema_packages/properties/electronic.py b/src/nomad_simulations/schema_packages/properties/electronic.py index 40070c1d..7d3d4d7f 100644 --- a/src/nomad_simulations/schema_packages/properties/electronic.py +++ b/src/nomad_simulations/schema_packages/properties/electronic.py @@ -3,7 +3,10 @@ import numpy as np from nomad.config import config from nomad.metainfo import Quantity, SchemaPackage, Section, SubSection +from nomad.datamodel.metainfo.plot import PlotSection, PlotlyFigure from nomad_simulations.schema_packages.general import ModelBaseSection +import plotly.graph_objects as go +import plotly.express as px configuration = config.get_plugin_entry_point( 'nomad_simulations.schema_packages:nomad_simulations_plugin' @@ -15,8 +18,10 @@ class Spin(ModelBaseSection): pass -class DOS(ModelBaseSection): + +class DOS(PlotSection, ModelBaseSection): """Collection of Electronic Density of States""" + m_def = Section() class SemanticDOS(ModelBaseSection): @@ -28,33 +33,34 @@ class SpinResolvedDOS(ModelBaseSection): spin = Quantity( type=Spin, - description="Spin channel", + description='Spin channel', ) values = Quantity( type=np.float64, shape=['*'], - description="Actual DOS values", + description='Actual DOS values', ) # ? add renormalized_values - def name_from_section(self, section): + def name_from_section(self, section) -> str: return self.spin.name_from_section() label = Quantity( type=str, - description="Label of the DOS", + default='total', + description='Label of the DOS', ) # TODO: el n m energies = Quantity( type=np.float64, unit='J', shape=['*'], - description="Energy values at which the DOS is evaluated", + description='Energy values at which the DOS is evaluated', ) spin_channels = SubSection(subsection=SpinResolvedDOS.m_def, repeats=True) - def name_from_section(self, section): + def name_from_section(self, section) -> str: if section.label: return section.label else: @@ -62,5 +68,29 @@ def name_from_section(self, section): collections = SubSection(subsection=SemanticDOS.m_def, repeats=True) + def generate_plot(self) -> go.Figure: + fig = go.Figure() + for collection in self.collections: + for spin_channel in collection.spin_channels: + fig.add_trace( + px.line( + x=collection.energies, + y=spin_channel.values, + name=f'{collection.name_from_section()} {spin_channel.spin.name_from_section()}', + ) + ) + return fig + + def normalize(self, archive, logger): + super().normalize(archive, logger) + # this does not check if the plot was already stored + self.figures.append( + PlotlyFigure( + label='Full DOS', + index=0, + figure=self.generate_plot().to_plotly_json(), + ) + ) + m_package.__init_metainfo__()