Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a source for convenient plotting of tabular data #24

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/earthkit/plots/sources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from earthkit.plots.sources.earthkit import EarthkitSource
from earthkit.plots.sources.numpy import NumpySource
from earthkit.plots.sources.tabular import TabularSource
from earthkit.plots.sources.xarray import XarraySource


Expand Down Expand Up @@ -56,4 +57,6 @@ def get_source(*args, data=None, x=None, y=None, z=None, u=None, v=None, **kwarg
cls = XarraySource
elif isinstance(core_data, ek_data.core.Base):
cls = EarthkitSource
elif core_data.__class__.__name__ in ("DataFrame", "Series"):
cls = TabularSource
return cls(*args, data=data, x=x, y=y, z=z, u=u, v=v, **kwargs)
120 changes: 120 additions & 0 deletions src/earthkit/plots/sources/tabular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright 2024, European Centre for Medium Range Weather Forecasts.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import cached_property

import numpy as np

from earthkit.plots import identifiers
from earthkit.plots.sources.single import SingleSource


class TabularSource(SingleSource):
"""
Source class for tabular data.

Parameters
----------
data : xarray.Dataset
The data to be plotted.
x : str, optional
The x-coordinate variable in data.
y : str, optional
The y-coordinate variable in data.
z : str, optional
The z-coordinate variable in data.
u : str, optional
The u-component variable in data.
v : str, optional
The v-component variable in data.
crs : cartopy.crs.CRS, optional
The CRS of the data.
**kwargs
Metadata keys and values to attach to this Source.
"""

@cached_property
def data(self):
"""The underlying xarray data."""
# Promote a column (e.g., pandas or polars Series) to a DataFrame
if len(self._data.shape) == 1:
return self._data.to_frame()
return self._data

def metadata(self, key, default=None):
"""
Extract metadata from the data.

Parameters
----------
key : str
The metadata key to extract.
default : any, optional
The default value to return if the key is not found.
"""
if key == "variable_name":
# 2D data: use label of z column
if isinstance(self._z, str):
return self._z
# 1D data: use label of y column
if isinstance(self._y, str):
return self._y
return super().metadata(key, default)

@property
def _nrows(self):
return self.data.shape[0]

@property
def _ncols(self):
return self.data.shape[1]

def _column_values(self, name):
return self.data[name].to_numpy().squeeze()

@cached_property
def x_values(self):
# Column name specified explicitly or identified from standard set. Note
# that this means that identified columns take precedence over the index
# of a pandas DataFrame or Series.
if self._x is None:
self._x = identifiers.find_x(self.data.columns)
if self._x is not None:
return self._column_values(self._x)
# Table has an index (e.g., pandas.DataFrame)
if hasattr(self.data, "index"):
x = self.data.index
self._x = x.name
return x.to_numpy()
# Fallback: count upwards from 0
return np.arange(self._nrows)

@cached_property
def y_values(self):
# Column name specified explicitly or identified from standard set
if self._y is None:
self._y = identifiers.find_y(self.data.columns)
if self._y is not None:
return self._column_values(self._y)
# Single-column dataset
if self._ncols == 1:
self._y = self.data.columns[0]
return self.data.to_numpy().squeeze()
return None

@cached_property
def z_values(self):
if isinstance(self._z, str):
return self._column_values(self._z)
return None
65 changes: 65 additions & 0 deletions tests/sources/test_tabular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2024, European Centre for Medium Range Weather Forecasts.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

try:
import pandas as pd
except ImportError:
pytest.skip(
"skipping tests in sources/tabular (no pandas)", allow_module_level=True
)

from earthkit.plots.sources.tabular import TabularSource


def test_TabularSource_Series():
series = pd.Series([4, 5, 7])
source = TabularSource(series)
assert np.array_equal(source.x_values, [0, 1, 2]) # auto-generated index
assert np.array_equal(source.y_values, [4, 5, 7])
assert source.z_values is None


def test_TabularSource_singlecol():
df = pd.DataFrame({"values": [4, 5, 7]})
source = TabularSource(df)
assert np.array_equal(source.x_values, [0, 1, 2]) # auto-generated index
assert np.array_equal(source.y_values, [4, 5, 7])
assert source.z_values is None


def test_TabularSource_multicol_identification():
df = pd.DataFrame({"x": [3, 4, 5], "y": [4, 5, 7]})
source = TabularSource(df)
assert np.array_equal(source.x_values, [3, 4, 5])
assert np.array_equal(source.y_values, [4, 5, 7])
assert source.z_values is None


def test_TabularSource_multicol_manual_2D():
df = pd.DataFrame({"foo": [4, 5, 6], "y": [3, 2, 1], "baz": [7, 8, 9]})
source = TabularSource(df, x="y", y="foo") # override y-detection
assert np.array_equal(source.x_values, [3, 2, 1])
assert np.array_equal(source.y_values, [4, 5, 6])
assert source.z_values is None


def test_TabularSource_multicol_manual_3D():
df = pd.DataFrame({"foo": [4, 5, 6], "bar": [3, 2, 1], "baz": [7, 8, 9]})
source = TabularSource(df, x="baz", y="foo", z="bar")
assert np.array_equal(source.x_values, [7, 8, 9])
assert np.array_equal(source.y_values, [4, 5, 6])
assert np.array_equal(source.z_values, [3, 2, 1])
Loading