Skip to content

Commit

Permalink
Merge pull request #1 from FAIRmat-NFDI/NXmpes_arpes_import
Browse files Browse the repository at this point in the history
NXmpes arpes import
  • Loading branch information
rettigl authored Apr 4, 2024
2 parents b30378e + e5fa444 commit 6e967fc
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 40 deletions.
134 changes: 110 additions & 24 deletions arpes/endstations/plugin/nexus.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,35 @@
from typing import Any, Dict, Union
import xarray as xr
from collections.abc import Sequence
from typing import Optional, Union

import h5py
import numpy as np
from arpes.endstations import SingleFileEndstation, add_endstation
import xarray as xr
from pint import Quantity

from arpes.config import ureg
from arpes.endstations import SingleFileEndstation, add_endstation

__all__ = ["NeXusEndstation"]

# Maps NXmpes field paths to the short pyARPES names used by this loader.
# Keys are HDF5 paths with the leading entry group removed (the form the
# metadata keys take after the first path component is stripped); values are
# the names under which the same quantity is stored as a pyARPES attribute or
# used as a coordinate/dimension name of the loaded DataArray.
nexus_translation_table = {
"sample/transformations/trans_x": "x",
"sample/transformations/trans_y": "y",
"sample/transformations/trans_z": "z",
"sample/transformations/sample_polar": "theta",
"sample/transformations/offset_polar": "theta_offset",
"sample/transformations/sample_tilt": "beta",
"sample/transformations/offset_tilt": "beta_offset",
"sample/transformations/sample_azimuth": "chi",
"sample/transformations/offset_azimuth": "chi_offset",
"instrument/beam_probe/incident_energy": "hv",
"instrument/electronanalyser/work_function": "work_function",
"instrument/electronanalyser/transformations/analyzer_rotation": "alpha",
"instrument/electronanalyser/transformations/analyzer_elevation": "psi",
"instrument/electronanalyser/transformations/analyzer_dispersion": "phi",
"instrument/electronanalyser/energydispersion/kinetic_energy": "eV",
}


class NeXusEndstation(SingleFileEndstation):
"""An endstation for reading arpes data from a nexus file."""

Expand All @@ -16,29 +39,36 @@ class NeXusEndstation(SingleFileEndstation):
".nxs",
}


def load_nexus_file(self, filepath: str, entry_name: str = "entry") -> xr.DataArray:
"""Loads a MPES NeXus file and creates a DataArray from it.
"""
Loads an MPES NeXus file and creates a DataArray from it.
Args:
filepath (str): The path of the .nxs file.
entry_name (str, optional):
The name of the entry to process. Defaults to "entry".
Raises:
KeyError:
Raised if a dependent axis is not found in the NeXus file.
Returns:
xr.DataArray: The data read from the .nxs file.
"""

def write_value(name: str, dataset: h5py.Dataset):
if str(dataset.dtype) == 'bool':
if str(dataset.dtype) == "bool":
attributes[name] = bool(dataset[()])
elif dataset.dtype.kind in 'iufc':
elif dataset.dtype.kind in "iufc":
attributes[name] = dataset[()]
if 'units' in dataset.attrs:
attributes[name] = attributes[name] * ureg(dataset.attrs['units'])
if "units" in dataset.attrs:
attributes[name] = attributes[name] * ureg(dataset.attrs["units"])
elif dataset.dtype.kind in "O" and dataset.shape == ():
attributes[name] = dataset[()].decode()

def is_valid_metadata(name: str) -> bool:
invalid_end_paths = ['depends_on']
invalid_start_paths = ['data', 'process']
invalid_end_paths = ["depends_on"]
invalid_start_paths = ["data", "process"]
for invalid_path in invalid_start_paths:
if name.startswith(invalid_path):
return False
Expand All @@ -48,29 +78,85 @@ def is_valid_metadata(name: str) -> bool:
return True

def parse_attrs(name: str, dataset: Union[h5py.Dataset, h5py.Group]):
short_path = name.split('/', 1)[-1]
short_path = name.split("/", 1)[-1]
if isinstance(dataset, h5py.Dataset) and is_valid_metadata(short_path):
write_value(short_path, dataset)

def translate_nxmpes_to_pyarpes(attributes: dict) -> dict:
    """Add pyARPES-named copies of the NXmpes metadata to ``attributes``.

    Every path of ``nexus_translation_table`` that is present in
    ``attributes`` is stored again under its pyARPES name. Values carrying
    units of degrees are converted to radians, and values whose pyARPES
    name contains "offset" have their sign flipped. Afterwards, array-valued
    entries named in ``self.ENSURE_COORDS_EXIST`` are reduced to their first
    element.

    Args:
        attributes (dict): Metadata keyed by NeXus path. Modified in place.

    Returns:
        dict: The same ``attributes`` dictionary that was passed in.
    """
    for key, newkey in nexus_translation_table.items():
        if key not in attributes:
            continue

        value = attributes[key]
        try:
            if value.units == "degree":
                value = value.to(ureg.rad)
        except AttributeError:
            # Values without pint units are copied over unchanged.
            pass

        # flip sign of offsets, as they are subtracted in pyARPES rather than added
        if "offset" in newkey:
            # Multiply out of place: when no unit conversion happened, ``value``
            # is the very object stored under ``key``, and an in-place ``*=``
            # on an array would silently negate that original entry as well.
            value = value * -1

        attributes[newkey] = value

    # remove axis arrays from static coordinates:
    for axis in self.ENSURE_COORDS_EXIST:
        if axis not in attributes:
            continue

        value = attributes[axis]
        is_array_like = isinstance(value, (Sequence, np.ndarray)) or (
            isinstance(value, Quantity)
            and isinstance(value.magnitude, (Sequence, np.ndarray))
        )
        if is_array_like and len(value) > 0:
            attributes[axis] = value[0]

    return attributes

def load_nx_data(nxdata: h5py.Group, attributes: dict) -> xr.DataArray:
    """Build a DataArray from an NXdata group.

    Each axis listed in the group's ``axes`` attribute is renamed to the
    pyARPES name that ``nexus_translation_table`` assigns to the path given
    by the ``<axis>_depends`` attribute. Axis values are read together with
    their ``units`` attribute, and axes in degrees are converted to radians.

    Args:
        nxdata (h5py.Group): The NXdata group holding the signal and axes.
        attributes (dict): Metadata attached to the result. Entries that
            share a name with one of the axes are removed from it in place.

    Raises:
        KeyError: If an axis has no ``<axis>_depends`` attribute, or if the
            path it points to is not in ``nexus_translation_table``.

    Returns:
        xr.DataArray: The signal with translated dimensions and coordinates.
    """
    axes = nxdata.attrs["axes"]
    # A single axis may be stored as a scalar string attribute, which h5py
    # hands back as ``str``; iterating that would yield single characters.
    if isinstance(axes, str):
        axes = [axes]

    # handle moving axes
    new_axes = []
    for axis in axes:
        depends_attr = f"{axis}_depends"
        if depends_attr not in nxdata.attrs:
            raise KeyError(f"Dependent axis field not found for axis {axis}.")

        axis_depends: str = nxdata.attrs[depends_attr]
        # Drop the leading "/<entry>/" part to obtain the translation-table key.
        axis_depends_key = axis_depends.split("/", 2)[-1]
        new_axis = nexus_translation_table[axis_depends_key]
        new_axes.append(new_axis)
        # A scanned axis must not also appear as a static attribute.
        attributes.pop(new_axis, None)

    coords = {}
    for axis, new_axis in zip(axes, new_axes):
        values = nxdata[axis][:] * ureg(nxdata[axis].attrs["units"])
        if values.units == "degree":
            values = values.to(ureg.rad)
        coords[new_axis] = values

    data = nxdata[nxdata.attrs["signal"]][:]

    return xr.DataArray(data, coords=coords, dims=new_axes, attrs=attributes)

data_path = f"/{entry_name}/data"
with h5py.File(filepath, "r") as h5file:
attributes = {}
h5file.visititems(parse_attrs)
return xr.DataArray(
h5file[f"/{data_path}/data"][:],
coords={
"delay": h5file[f"{data_path}/delay"][:],
"eV": np.transpose(h5file[f"{data_path}/energy"][:]),
"kx": h5file[f"{data_path}/kx"][:],
"ky": h5file[f"{data_path}/ky"][:],
},
dims=["kx", "ky", "eV", "delay"],
attrs=attributes
)
attributes = translate_nxmpes_to_pyarpes(attributes)
dataset = load_nx_data(h5file[data_path], attributes)
return dataset

def load_single_frame(
self, frame_path: str = None, scan_desc: dict = None, **kwargs
self,
frame_path: Optional[str] = None,
scan_desc: Optional[dict] = None,
**kwargs,
) -> xr.Dataset:
if frame_path is None:
return xr.Dataset()
data = self.load_nexus_file(frame_path)
return xr.Dataset({"spectrum": data}, attrs=data.attrs)

Expand Down
2 changes: 1 addition & 1 deletion arpes/fits/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def broadcast_model(

other_axes = set(data.dims).difference(set(broadcast_dims))
template = data.sum(list(other_axes))
template.values = np.ndarray(template.shape, dtype=np.object)
template.values = np.ndarray(template.shape, dtype=object)
n_fits = np.prod(np.array(list(template.S.dshape.values())))

if parallelize is None:
Expand Down
2 changes: 1 addition & 1 deletion arpes/utilities/conversion/bounds_calculations.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def calculate_kx_ky_bounds(arr: xr.DataArray):
beta_mid,
]
)
kinetic_energy = arr.coords["eV"].values.max()
kinetic_energy = arr.S.hv - arr.S.work_function + arr.coords["eV"].values.max()

kxs = arpes.constants.K_INV_ANGSTROM * np.sqrt(kinetic_energy) * np.sin(sampled_phi_values)
kys = (
Expand Down
30 changes: 24 additions & 6 deletions arpes/utilities/conversion/kx_ky_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
Broadly, this covers cases where we are not performing photon energy scans.
"""
import numpy as np
import math
from typing import Any, Callable, Dict, List

import numba
import math
import numpy as np
import pint
import xarray as xr

import arpes.constants
import xarray as xr
from typing import Any, Callable, Dict, List

from .base import CoordinateConverter, K_SPACE_BORDER, MOMENTUM_BREAKPOINTS
from .base import K_SPACE_BORDER, MOMENTUM_BREAKPOINTS, CoordinateConverter
from .bounds_calculations import calculate_kp_bounds, calculate_kx_ky_bounds

__all__ = ["ConvertKp", "ConvertKxKy"]
Expand Down Expand Up @@ -74,6 +75,11 @@ def _safe_compute_k_tot(hv, work_function, binding_energy):

return k_tot

def _strip_units(val):
    """Return the magnitude of a pint quantity; pass any other value through."""
    return val.magnitude if isinstance(val, pint.Quantity) else val


class ConvertKp(CoordinateConverter):
"""A momentum converter for single ARPES (kp) cuts."""
Expand Down Expand Up @@ -141,6 +147,12 @@ def kspace_to_phi(
)
parallel_angle = self.arr.S.lookup_offset_coord("theta")

offset = self.arr.S.phi_offset + parallel_angle

polar_angle = _strip_units(polar_angle)
parallel_angle = _strip_units(parallel_angle)
offset = _strip_units(offset)

if self.k_tot is None:
self.compute_k_tot(binding_energy)

Expand All @@ -152,7 +164,7 @@ def kspace_to_phi(
kp / np.cos(polar_angle),
self.k_tot,
self.phi,
self.arr.S.phi_offset + parallel_angle,
offset,
par_tot,
False,
)
Expand Down Expand Up @@ -309,6 +321,7 @@ def rkx_rky(self, kx, ky):
return self.rkx, self.rky

chi = self.arr.S.lookup_offset_coord("chi")
chi = _strip_units(chi)

self.rkx = np.zeros_like(kx)
self.rky = np.zeros_like(ky)
Expand All @@ -335,6 +348,7 @@ def kspace_to_phi(
scan_angle = self.direct_angles[1]
self.phi = np.zeros_like(ky)
offset = self.arr.S.phi_offset + self.arr.S.lookup_offset_coord(self.parallel_angles[0])
offset = _strip_units(offset)

par_tot = isinstance(self.k_tot, np.ndarray) and len(self.k_tot) != 1
assert len(self.k_tot) == len(self.phi) or len(self.k_tot) == 1
Expand Down Expand Up @@ -386,21 +400,25 @@ def kspace_to_perp_angle(
offset = self.arr.S.psi_offset - self.arr.S.lookup_offset_coord(
self.parallel_angles[1]
)
offset = _strip_units(offset)
_small_angle_arcsin(kx, self.k_tot, self.perp_angle, offset, par_tot, True)
else:
offset = self.arr.S.psi_offset + self.arr.S.lookup_offset_coord(
self.parallel_angles[1]
)
offset = _strip_units(offset)
_small_angle_arcsin(ky, self.k_tot, self.perp_angle, offset, par_tot, False)
elif scan_angle == "beta":
offset = self.arr.S.beta_offset + self.arr.S.lookup_offset_coord(
self.parallel_angles[1]
)
offset = _strip_units(offset)
_exact_arcsin(ky, kx, self.k_tot, self.perp_angle, offset, par_tot, True)
elif scan_angle == "theta":
offset = self.arr.S.theta_offset - self.arr.S.lookup_offset_coord(
self.parallel_angles[1]
)
offset = _strip_units(offset)
_exact_arcsin(kx, ky, self.k_tot, self.perp_angle, offset, par_tot, True)
else:
raise ValueError(
Expand Down
16 changes: 8 additions & 8 deletions arpes/xarray_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@
The main accessors are `.S`, `.G`, `.X`, and `.F`.
The `.S` accessor:
The `.S` accessor contains functionality related to spectroscopy. Utilities
The `.S` accessor contains functionality related to spectroscopy. Utilities
which only make sense in this context should be placed here, while more generic
tools should be placed elsewhere.
The `.G` accessor:
This is a general-purpose collection of tools which exists to provide conveniences
over what already exists in the xarray data model. As an example, there are
various tools for simultaneous iteration of data and coordinates here, as well as
over what already exists in the xarray data model. As an example, there are
various tools for simultaneous iteration of data and coordinates here, as well as
for vectorized application of functions to data or coordinates.
The `.X` accessor:
This is an accessor which contains tools related to selecting and subselecting
The `.X` accessor:
This is an accessor which contains tools related to selecting and subselecting
data. The two most notable tools here are `.X.first_exceeding` which is very useful
for initializing curve fits and `.X.max_in_window` which is useful for refining
for initializing curve fits and `.X.max_in_window` which is useful for refining
these initial parameter choices.
The `.F.` accessor:
Expand Down Expand Up @@ -2687,7 +2687,7 @@ def p(self, param_name: str) -> xr.DataArray:
The output array is infilled with `np.nan` if the fit did not converge/
the fit result is `None`.
"""
return self._obj.G.map(param_getter(param_name), otypes=[np.float])
return self._obj.G.map(param_getter(param_name), otypes=[float])

def s(self, param_name: str) -> xr.DataArray:
"""Collects the standard deviation of a parameter from fitting.
Expand All @@ -2704,7 +2704,7 @@ def s(self, param_name: str) -> xr.DataArray:
The output array is infilled with `np.nan` if the fit did not converge/
the fit result is `None`.
"""
return self._obj.G.map(param_stderr_getter(param_name), otypes=[np.float])
return self._obj.G.map(param_stderr_getter(param_name), otypes=[float])

@property
def bands(self) -> Dict[str, MultifitBand]:
Expand Down

0 comments on commit 6e967fc

Please sign in to comment.