Skip to content

Commit

Permalink
only one type of plot -> inside ArrheniusRegressor
Browse files Browse the repository at this point in the history
  • Loading branch information
fernandezfran committed Jul 12, 2023
1 parent 1c5aee7 commit dda8187
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 31 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ include LICENSE
include README.md
include pyproject.toml
include sierras.py
include test_sierras.py

exclude requirements_dev.txt
exclude tox.ini
exclude .readthedocs.yml

recursive-exclude tests *
recursive-exclude docs *
34 changes: 5 additions & 29 deletions sierras.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@
# CONSTANTS
# =============================================================================

__all__ = ["ArrheniusRegressor", "ArrheniusPlotter"]

__all__ = ["ArrheniusRegressor"]

NAME = "sierras"

Expand Down Expand Up @@ -188,26 +187,8 @@ def to_dataframe(self, X, y, sample_weight=None):

return df

@property
def plot(self):
"""Arrhenius plot accessor."""
return ArrheniusPlotter(self)


class ArrheniusPlotter:
"""Arrhenius plot.
Parameters
----------
areg : sierras.ArrheniusRegressor
An ArrheniusRegressor already fitted.
"""

def __init__(self, areg):
self.areg = areg

def arrhenius(self, X, y, ax=None, data_kws=None, pred_kws=None):
"""Arrhenius plot function.
def plot(self, ax=None, data_kws=None, pred_kws=None):
"""Arrhenius plotter.
Parameters
----------
Expand Down Expand Up @@ -237,12 +218,7 @@ def arrhenius(self, X, y, ax=None, data_kws=None, pred_kws=None):

pred_kws.setdefault("label", "fit")

ax.errorbar(
self.areg._X,
self.areg._y,
yerr=self.areg._sample_weight,
**data_kws,
)
ax.plot(self.areg._X, self.areg.reg_.predict(self.areg._X), **pred_kws)
ax.errorbar(self._X, self._y, yerr=self._sample_weight, **data_kws)
ax.plot(self._X, self.reg_.predict(self._X), **pred_kws)

return ax
2 changes: 1 addition & 1 deletion test_sierras.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_plot(self, fig_test, fig_ref, dataset, request):

# test
test_ax = fig_test.subplots()
areg.plot.arrhenius(X, y, ax=test_ax)
areg.plot(ax=test_ax)

# expected
ref_ax = fig_ref.subplots()
Expand Down

0 comments on commit dda8187

Please sign in to comment.