From 66f234bc5b99ca79d4ba517b4ea0e9f5d721bbaf Mon Sep 17 00:00:00 2001
From: ndaelman <ndaelman@physik.hu-berlin.de>
Date: Tue, 28 Jan 2025 16:13:39 +0100
Subject: [PATCH] Add plotly visualization

---
 .../schema_packages/properties/electronic.py  | 44 ++++++++++++++++---
 1 file changed, 37 insertions(+), 7 deletions(-)

diff --git a/src/nomad_simulations/schema_packages/properties/electronic.py b/src/nomad_simulations/schema_packages/properties/electronic.py
index 40070c1d..7d3d4d7f 100644
--- a/src/nomad_simulations/schema_packages/properties/electronic.py
+++ b/src/nomad_simulations/schema_packages/properties/electronic.py
@@ -3,7 +3,10 @@
 import numpy as np
 from nomad.config import config
 from nomad.metainfo import Quantity, SchemaPackage, Section, SubSection
+from nomad.datamodel.metainfo.plot import PlotSection, PlotlyFigure
 from nomad_simulations.schema_packages.general import ModelBaseSection
+import plotly.graph_objects as go
+import plotly.express as px
 
 configuration = config.get_plugin_entry_point(
     'nomad_simulations.schema_packages:nomad_simulations_plugin'
@@ -15,8 +18,10 @@
 class Spin(ModelBaseSection):
     pass
 
-class DOS(ModelBaseSection):
+
+class DOS(PlotSection, ModelBaseSection):
     """Collection of Electronic Density of States"""
+
     m_def = Section()
 
     class SemanticDOS(ModelBaseSection):
@@ -28,33 +33,34 @@ class SpinResolvedDOS(ModelBaseSection):
 
             spin = Quantity(
                 type=Spin,
-                description="Spin channel",
+                description='Spin channel',
             )
 
             values = Quantity(
                 type=np.float64,
                 shape=['*'],
-                description="Actual DOS values",
+                description='Actual DOS values',
             )  # ? add renormalized_values
 
-            def name_from_section(self, section):
+            def name_from_section(self, section) -> str:
                 return self.spin.name_from_section()
 
         label = Quantity(
             type=str,
-            description="Label of the DOS",
+            default='total',
+            description='Label of the DOS',
         )  # TODO: el n m
 
         energies = Quantity(
             type=np.float64,
             unit='J',
             shape=['*'],
-            description="Energy values at which the DOS is evaluated",
+            description='Energy values at which the DOS is evaluated',
         )
 
         spin_channels = SubSection(subsection=SpinResolvedDOS.m_def, repeats=True)
 
-        def name_from_section(self, section):
+        def name_from_section(self, section) -> str:
             if section.label:
                 return section.label
             else:
@@ -62,5 +68,29 @@ def name_from_section(self, section):
 
     collections = SubSection(subsection=SemanticDOS.m_def, repeats=True)
 
+    def generate_plot(self) -> go.Figure:
+        fig = go.Figure()
+        for collection in self.collections:
+            for spin_channel in collection.spin_channels:
+                fig.add_trace(
+                    px.line(
+                        x=collection.energies,
+                        y=spin_channel.values,
+                        name=f'{collection.name_from_section()} {spin_channel.spin.name_from_section()}',
+                    )
+                )
+        return fig
+    
+    def normalize(self, archive, logger):
+        super().normalize(archive, logger)
+        # this does not check if the plot was already stored
+        self.figures.append(
+            PlotlyFigure(
+                label='Full DOS',
+                index=0,
+                figure=self.generate_plot().to_plotly_json(),
+            )
+        )
+
 
 m_package.__init_metainfo__()