Skip to content

Commit

Permalink
feat: add moving average plot (#836)
Browse files Browse the repository at this point in the history
Closes #XYZ

### Summary of Changes

<!-- Please provide a summary of changes in this pull request, ensuring
all changes are explained. -->
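Added a moving average plot (`moving_average_plot`) to the table plotter, together with tests and an updated data visualization tutorial.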

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
  • Loading branch information
Gerhardsa0 and megalinter-bot authored Jun 21, 2024
1 parent 2b82db7 commit abcf68a
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 34 deletions.
72 changes: 38 additions & 34 deletions docs/tutorials/data_visualization.ipynb

Large diffs are not rendered by default.

75 changes: 75 additions & 0 deletions src/safeds/data/tabular/plotting/_table_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,81 @@ def scatter_plot(self, x_name: str, y_names: list[str]) -> Image:

return _figure_to_image(fig)

def moving_average_plot(self, x_name: str, y_name: str, window_size: int) -> Image:
    """
    Create a moving average plot for the y column and plot it by the x column in the table.

    Rows that share the same x value are first collapsed into the mean of their y values. The moving average is
    then computed over these per-x means, ordered by x. Points for which the rolling mean is undefined (NaN) are
    not drawn.

    Parameters
    ----------
    x_name:
        The name of the column to be plotted on the x-axis.
    y_name:
        The name of the column to be plotted on the y-axis.
    window_size:
        The number of consecutive (grouped) data points that are averaged for each plotted point.

    Returns
    -------
    plot:
        The plot as an image.

    Raises
    ------
    ColumnNotFoundError
        If a column does not exist.
    ColumnTypeError
        If a column has a type that cannot be plotted, e.g. a non-numeric y column.
    ValueError
        If the x or y column contains missing values.

    Examples
    --------
    >>> from safeds.data.tabular.containers import Table
    >>> table = Table(
    ...     {
    ...         "a": [1, 2, 3, 4, 5],
    ...         "b": [2, 3, 4, 5, 6],
    ...     }
    ... )
    >>> image = table.plot.moving_average_plot("a", "b", window_size=2)
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import polars as pl

    _plot_validation(self._table, x_name, [y_name])
    for name in [x_name, y_name]:
        if self._table.get_column(name).missing_value_count() >= 1:
            raise ValueError(
                f"there are missing values in column '{name}', use transformation to fill missing values "
                "or drop the missing values. For a moving average no missing values are allowed.",
            )

    # Collapse duplicate x values into their mean y value. The sort has to happen AFTER the aggregation:
    # polars' group_by does not guarantee the order of the resulting groups, so sorting beforehand could
    # leave the groups (and therefore the rolling window) in arbitrary order.
    grouped = (
        self._table._lazy_frame.group_by(x_name)
        .agg(pl.col(y_name).mean().alias(y_name))
        .sort(x_name)
        .collect()
    )
    moving_average = grouped.select([pl.col(y_name).rolling_mean(window_size).alias("moving_average")])

    # Set up the arrays for plotting; drop the positions where the rolling mean is NaN
    # and keep the x values aligned with the remaining y values.
    y_data_with_nan = moving_average["moving_average"].to_numpy()
    nan_mask = ~np.isnan(y_data_with_nan)
    y_data = y_data_with_nan[nan_mask]
    x_data = grouped[x_name].to_numpy()[nan_mask]

    fig, ax = plt.subplots()
    ax.plot(x_data, y_data, label="moving average")
    ax.set(
        xlabel=x_name,
        ylabel=y_name,
    )
    ax.legend()
    if self._table.get_column(x_name).is_temporal:
        ax.set_xticks(x_data)  # Set x-ticks to the x data points
    # Re-apply the current tick positions before assigning tick labels explicitly below.
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(
        ax.get_xticklabels(),
        rotation=45,
        horizontalalignment="right",
    )  # rotate the labels of the x Axis to prevent the chance of overlapping of the labels
    fig.tight_layout()

    return _figure_to_image(fig)


def _plot_validation(table: Table, x_name: str, y_names: list[str]) -> None:
y_names.append(x_name)
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
104 changes: 104 additions & 0 deletions tests/safeds/data/tabular/plotting/test_moving_average_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import datetime

import pytest
from safeds.data.tabular.containers import Table
from safeds.exceptions import ColumnNotFoundError, ColumnTypeError
from syrupy import SnapshotAssertion


@pytest.mark.parametrize(
    ("table", "x_name", "y_name", "window_size"),
    [
        pytest.param(
            Table({"A": [1, 2, 3], "B": [2, 4, 7]}),
            "A",
            "B",
            2,
            id="numerical",
        ),
        # NOTE(review): numeric case with duplicate x values, left disabled by the original author:
        # (Table({"A": [1, 1, 2, 2, 3, 3, 4, 4, 5, 5], "B": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]}), "A", "B", 2),
        pytest.param(
            Table(
                {
                    # Every date appears twice, so the rows are grouped by date before averaging.
                    "time": [datetime.date(2022, 1, day) for day in (10, 10, 11, 11, 12, 12)],
                    "A": [10, 5, 20, 2, 1, 1],
                },
            ),
            "time",
            "A",
            2,
            id="date grouped",
        ),
        pytest.param(
            Table(
                {
                    "time": [datetime.date(2022, 1, day) for day in (9, 10, 11, 12)],
                    "A": [10, 5, 20, 2],
                },
            ),
            "time",
            "A",
            2,
            id="date",
        ),
    ],
)
def test_should_match_snapshot(
    table: Table,
    x_name: str,
    y_name: str,
    window_size: int,
    snapshot_png_image: SnapshotAssertion,
) -> None:
    """The rendered moving average plot must be identical to the stored PNG snapshot."""
    moving_average_image = table.plot.moving_average_plot(x_name, y_name, window_size)
    assert moving_average_image == snapshot_png_image


@pytest.mark.parametrize(
    ("x", "y"),
    [
        pytest.param("C", "A", id="x column"),
        pytest.param("A", "C", id="y column"),
        pytest.param("C", "D", id="x and y column"),
    ],
)
def test_should_raise_if_column_does_not_exist_error_message(x: str, y: str) -> None:
    """Requesting a column that is not in the table ("C" or "D") must raise a ColumnNotFoundError."""
    two_column_table = Table({"A": [1, 2, 3], "B": [2, 4, 7]})

    with pytest.raises(ColumnNotFoundError):
        two_column_table.plot.moving_average_plot(x, y, window_size=2)


@pytest.mark.parametrize(
    "table",
    [
        # "A" is passed as the x column and "B" as the y column below, so each id names
        # the column that actually contains the string value.
        pytest.param(Table({"A": ["1", 2, 3], "B": [2, 4, 7]}), id="x column"),
        pytest.param(Table({"A": [1, 2, 3], "B": ["2", 4, 7]}), id="y column"),
    ],
)
def test_should_raise_if_column_is_not_numerical(table: Table) -> None:
    """A string value in the x or the y column must raise a ColumnTypeError."""
    with pytest.raises(ColumnTypeError):
        table.plot.moving_average_plot("A", "B", window_size=2)


@pytest.mark.parametrize(
    ("table", "column_name"),
    [
        # "A" is passed as the x column and "B" as the y column below, so each id names
        # the column that actually contains the missing value.
        pytest.param(Table({"A": [None, 2, 3], "B": [2, 4, 7]}), "A", id="x column"),
        pytest.param(Table({"A": [1, 2, 3], "B": [None, 4, 7]}), "B", id="y column"),
    ],
)
def test_should_raise_if_column_has_missing_value(table: Table, column_name: str) -> None:
    """A missing value in the x or the y column must raise a ValueError that names the affected column."""
    with pytest.raises(
        ValueError,
        match=f"there are missing values in column '{column_name}', use transformation to fill missing "
        f"values or drop the missing values",
    ):
        table.plot.moving_average_plot("A", "B", window_size=2)

0 comments on commit abcf68a

Please sign in to comment.