Skip to content

Commit

Permalink
Compute linear regression between given year interval ; closes #76
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-florentin-charles committed Sep 21, 2024
1 parent 8d23635 commit dcaabc8
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 25 deletions.
17 changes: 13 additions & 4 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,12 +345,21 @@ def get_rainfall_averages(
"/graph/rainfall_linreg_slopes",
response_class=StreamingResponse,
summary="Retrieve rainfall monthly or seasonal linear regression slopes of data as a PNG.",
description=f"Time mode should be either '{TimeMode.MONTHLY.value}' or '{TimeMode.SEASONAL.value}'.",
description=f"Time mode should be either '{TimeMode.MONTHLY.value}' or '{TimeMode.SEASONAL.value}'.\n"
f"If no ending year is precised, most recent year available is taken: {all_rainfall.get_last_year()}.",
tags=["Graph"],
operation_id="getRainfallLinregSlopes",
)
def get_rainfall_linreg_slopes(time_mode: TimeMode):
linreg_slopes = all_rainfall.bar_rainfall_linreg_slopes(time_mode.value)
def get_rainfall_linreg_slopes(
time_mode: TimeMode,
begin_year: int,
end_year: int | None = None,
):
end_year = end_year or all_rainfall.get_last_year()

linreg_slopes = all_rainfall.bar_rainfall_linreg_slopes(
time_mode=time_mode.value, begin_year=begin_year, end_year=end_year
)
if linreg_slopes is None:
raise HTTPException(
status_code=400,
Expand All @@ -362,7 +371,7 @@ def get_rainfall_linreg_slopes(time_mode: TimeMode):
plt.close()
img_buffer.seek(0)

filename = f"rainfall_{time_mode.value}_linreg_slopes_{all_rainfall.starting_year}_{all_rainfall.get_last_year()}.png"
filename = f"rainfall_{time_mode.value}_linreg_slopes_{begin_year}_{end_year}.png"

return StreamingResponse(
img_buffer,
Expand Down
22 changes: 19 additions & 3 deletions src/core/models/all_rainfall.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,21 +360,37 @@ def bar_rainfall_averages(

return None

def bar_rainfall_linreg_slopes(self, time_mode: str) -> list | None:
def bar_rainfall_linreg_slopes(
self,
time_mode: str,
begin_year: int,
end_year: int | None = None,
) -> list | None:
"""
Plots a bar graphic displaying linear regression slope for each month or each season.
:param time_mode: A string setting the time period ['monthly', 'seasonal'].
:param begin_year: An integer representing the year
to start getting our rainfall values.
:param end_year: An integer representing the year
to end getting our rainfall values (optional).
Is set to last year available is None.
:return: A list of the Rainfall LinReg slopes for each month or season.
None if time_mode is not within {'monthly', 'seasonal'}.
"""
end_year = end_year or self.get_last_year()

if time_mode == TimeMode.MONTHLY.value:
return plotting.bar_monthly_rainfall_linreg_slopes(
list(self.monthly_rainfalls.values())
list(self.monthly_rainfalls.values()),
begin_year=begin_year,
end_year=end_year,
)
elif time_mode == TimeMode.SEASONAL.value:
return plotting.bar_seasonal_rainfall_linreg_slopes(
list(self.seasonal_rainfalls.values())
list(self.seasonal_rainfalls.values()),
begin_year=begin_year,
end_year=end_year,
)

return None
Expand Down
31 changes: 31 additions & 0 deletions src/core/models/yearly_rainfall.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,37 @@ def get_standard_deviation(
self.round_precision,
)

def get_linear_regression(
self, begin_year: int, end_year: int | None = None
) -> tuple[float, float]:
"""
Computes Linear Regression of rainfall according to year for a given time interval.
:param begin_year: An integer representing the year
to start getting our rainfall values.
:param end_year: An integer representing the year
to end getting our rainfall values (optional).
If not given, defaults to latest year available.
:return: a tuple containing two floats (r2 score, slope).
"""
end_year = end_year or self.get_last_year()

data = self.get_yearly_rainfall(begin_year, end_year)

years = data[Label.YEAR.value].values.reshape(-1, 1) # type: ignore
rainfalls = data[Label.RAINFALL.value].values

lin_reg = LinearRegression()
lin_reg.fit(years, rainfalls)
predicted_rainfalls = [
round(rainfall_value, self.round_precision)
for rainfall_value in lin_reg.predict(years).tolist()
]

return r2_score(rainfalls, predicted_rainfalls), round(
lin_reg.coef_[0], self.round_precision
)

def add_percentage_of_normal(
self, begin_year: int, end_year: int | None = None
) -> None:
Expand Down
38 changes: 22 additions & 16 deletions src/core/utils/functions/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,25 +121,28 @@ def bar_monthly_rainfall_averages(

def bar_monthly_rainfall_linreg_slopes(
monthly_rainfalls: list,
begin_year: int,
end_year: int,
) -> list:
"""
Plots a bar graphic displaying linear regression slope for each month passed through the dict.
If list is empty, does not plot anything and returns an empty list.
:param monthly_rainfalls: A list of instances of MonthlyRainfall.
To be purposeful, all instances should have the same time frame in years.
:param begin_year: An integer representing the year
to start getting our rainfall values.
:param end_year: An integer representing the year
to end getting our rainfall values.
:return: A list of the Rainfall LinReg slopes for each month.
"""
if not monthly_rainfalls:
return []

month_labels, slopes = [], []
for monthly_rainfall in monthly_rainfalls:
month_labels.append(monthly_rainfall.month.value[:3])
slopes.append(monthly_rainfall.add_linear_regression()[1])

begin_year = monthly_rainfalls[0].starting_year
end_year = monthly_rainfalls[0].get_last_year()
slopes.append(
monthly_rainfall.get_linear_regression(
begin_year=begin_year, end_year=end_year
)[1]
)

plt.bar(
month_labels,
Expand Down Expand Up @@ -188,25 +191,28 @@ def bar_seasonal_rainfall_averages(

def bar_seasonal_rainfall_linreg_slopes(
seasonal_rainfalls: list,
begin_year: int,
end_year: int,
) -> list:
"""
Plots a bar graphic displaying linear regression slope for each season passed through the dict.
If list is empty, does not plot anything and returns an empty list.
:param seasonal_rainfalls: A list of instances of SeasonalRainfall.
To be purposeful, all instances should have the same time frame in years.
:param begin_year: An integer representing the year
to start getting our rainfall values.
:param end_year: An integer representing the year
to end getting our rainfall values.
:return: A list of the Rainfall LinReg slopes for each season.
"""
if not seasonal_rainfalls:
return []

season_labels, slopes = [], []
for seasonal_rainfall in seasonal_rainfalls:
season_labels.append(seasonal_rainfall.season.value)
slopes.append(seasonal_rainfall.add_linear_regression()[1])

begin_year = seasonal_rainfalls[0].starting_year
end_year = seasonal_rainfalls[0].get_last_year()
slopes.append(
seasonal_rainfall.get_linear_regression(
begin_year=begin_year, end_year=end_year
)[1]
)

plt.bar(
season_labels,
Expand Down
7 changes: 7 additions & 0 deletions tst/core/models/test_yearly_rainfall.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ def test_get_standard_deviation():

assert isinstance(std_weighted_by_avg, float)

@staticmethod
def test_get_linear_regression():
r2_score, slope = YEARLY_RAINFALL.get_linear_regression(begin_year, end_year)

assert isinstance(r2_score, float) and r2_score <= 1
assert isinstance(slope, float)

@staticmethod
def test_add_percentage_of_normal():
YEARLY_RAINFALL.add_percentage_of_normal(YEARLY_RAINFALL.starting_year)
Expand Down
9 changes: 7 additions & 2 deletions tst/core/utils/functions/test_plotting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from src.core.utils.enums.labels import Label
from src.core.utils.functions import plotting
from tst.core.models.test_all_rainfall import begin_year, end_year
from tst.core.models.test_yearly_rainfall import YEARLY_RAINFALL, ALL_RAINFALL

BEGIN_YEAR = 1970
Expand Down Expand Up @@ -48,7 +49,9 @@ def test_bar_monthly_rainfall_averages():
@staticmethod
def test_bar_monthly_rainfall_linreg_slopes():
slopes = plotting.bar_monthly_rainfall_linreg_slopes(
list(ALL_RAINFALL.monthly_rainfalls.values())
list(ALL_RAINFALL.monthly_rainfalls.values()),
begin_year=begin_year,
end_year=end_year,
)

assert isinstance(slopes, list)
Expand All @@ -71,7 +74,9 @@ def test_bar_seasonal_rainfall_averages():
@staticmethod
def test_bar_seasonal_rainfall_linreg_slopes():
slopes = plotting.bar_seasonal_rainfall_linreg_slopes(
list(ALL_RAINFALL.seasonal_rainfalls.values())
list(ALL_RAINFALL.seasonal_rainfalls.values()),
begin_year=begin_year,
end_year=end_year,
)

assert isinstance(slopes, list)
Expand Down

0 comments on commit dcaabc8

Please sign in to comment.