Skip to content

Commit

Permalink
Issue/model names (#98)
Browse files Browse the repository at this point in the history
* add hard coded model names

* lint

* fix tests
  • Loading branch information
peterdudfield authored Nov 13, 2024
1 parent 2ab69bd commit 1bb40b8
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 15 deletions.
28 changes: 21 additions & 7 deletions src/india_api/internal/inputs/indiadb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def get_predicted_power_production_for_location(
self,
location: str,
asset_type: SiteAssetType,
ml_model_name: str,
forecast_horizon: ForecastHorizon = ForecastHorizon.latest,
forecast_horizon_minutes: Optional[int] = None,
smooth_flag: bool = True,
Expand All @@ -65,6 +66,7 @@ def get_predicted_power_production_for_location(
Args:
location: the location to get the predicted power production for
asset_type: The type of asset to get the forecast for
ml_model_name: The name of the model to get the forecast from
forecast_horizon: The time horizon to get the data for. Can be latest or day ahead
forecast_horizon_minutes: The number of minutes to get the forecast for. forecast_horizon must be 'horizon'
smooth_flag: Flag to smooth the forecast
Expand Down Expand Up @@ -109,6 +111,7 @@ def get_predicted_power_production_for_location(
day_ahead_hours=day_ahead_hours,
day_ahead_timezone_delta_hours=day_ahead_timezone_delta_hours,
forecast_horizon_minutes=forecast_horizon_minutes,
model_name=ml_model_name,
)
forecast_values: [ForecastValueSQL] = values[site.site_uuid]

Expand Down Expand Up @@ -183,12 +186,16 @@ def get_predicted_solar_power_production_for_location(
smooth_flag: Flag to smooth the forecast
"""

# set this to be hard coded for now
model_name = "pvnet_india"

return self.get_predicted_power_production_for_location(
location=location,
asset_type=SiteAssetType.pv,
forecast_horizon=forecast_horizon,
forecast_horizon_minutes=forecast_horizon_minutes,
smooth_flag=smooth_flag,
ml_model_name=model_name,
)

def get_predicted_wind_power_production_for_location(
Expand All @@ -208,12 +215,16 @@ def get_predicted_wind_power_production_for_location(
smooth_flag: Flag to smooth the forecast
"""

# set this to be hard coded for now
model_name = "windnet_india"

return self.get_predicted_power_production_for_location(
location=location,
asset_type=SiteAssetType.wind,
forecast_horizon=forecast_horizon,
forecast_horizon_minutes=forecast_horizon_minutes,
smooth_flag=smooth_flag
smooth_flag=smooth_flag,
ml_model_name=model_name,
)

def get_actual_solar_power_production_for_location(
Expand Down Expand Up @@ -261,11 +272,14 @@ def get_sites(self, email: str) -> list[internal.Site]:

return sites

def get_site_forecast(self, site_uuid: str, email:str) -> list[internal.PredictedPower]:
def get_site_forecast(self, site_uuid: str, email: str) -> list[internal.PredictedPower]:
"""Get a forecast for a site, this is for a solar site"""

# TODO feels like there is some duplicated code here which could be refactored

# hard coded model name
ml_model_name = "pvnet_ad_sites"

# Get the window
start, _ = get_window()

Expand All @@ -276,9 +290,7 @@ def get_site_forecast(self, site_uuid: str, email:str) -> list[internal.Predicte
site_uuid = UUID(site_uuid)

values = get_latest_forecast_values_by_site(
session,
site_uuids=[site_uuid],
start_utc=start,
session, site_uuids=[site_uuid], start_utc=start, model_name=ml_model_name
)
forecast_values: [ForecastValueSQL] = values[site_uuid]

Expand All @@ -296,7 +308,7 @@ def get_site_forecast(self, site_uuid: str, email:str) -> list[internal.Predicte

return values

def get_site_generation(self, site_uuid: str, email:str) -> list[internal.ActualPower]:
def get_site_generation(self, site_uuid: str, email: str) -> list[internal.ActualPower]:
"""Get the generation for a site, this is for a solar site"""

# TODO feels like there is some duplicated code here which could be refactored
Expand Down Expand Up @@ -328,7 +340,9 @@ def get_site_generation(self, site_uuid: str, email:str) -> list[internal.Actual

return values

def post_site_generation(self, site_uuid: str, generation: list[internal.ActualPower], email:str):
def post_site_generation(
self, site_uuid: str, generation: list[internal.ActualPower], email: str
):
"""Post generation for a site"""

with self._get_session() as session:
Expand Down
32 changes: 27 additions & 5 deletions src/india_api/internal/inputs/indiadb/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from pvsite_datamodel.sqlmodels import Base, ForecastSQL, ForecastValueSQL, GenerationSQL, SiteSQL
from pvsite_datamodel.read.user import get_user_by_email
from pvsite_datamodel.read.model import get_or_create_model
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from testcontainers.postgres import PostgresContainer
Expand Down Expand Up @@ -65,8 +66,8 @@ def sites(db_session):
ml_id=1,
asset_type="pv",
country="india",
region='testID',
client_site_name='ruvnl_pv_testID1'
region="testID",
client_site_name="ruvnl_pv_testID1",
)
db_session.add(site)
sites.append(site)
Expand All @@ -80,16 +81,16 @@ def sites(db_session):
ml_id=2,
asset_type="wind",
country="india",
region='testID',
client_site_name = 'ruvnl_wind_testID'
region="testID",
client_site_name="ruvnl_wind_testID",
)
db_session.add(site)
sites.append(site)

db_session.commit()

# create user
user = get_user_by_email(session=db_session, email='test@test.com')
user = get_user_by_email(session=db_session, email="test@test.com")
user.site_group.sites = sites

db_session.commit()
Expand Down Expand Up @@ -123,6 +124,23 @@ def generations(db_session, sites):
@pytest.fixture()
def forecast_values(db_session, sites):
"""Create some fake forecast values"""

make_fake_forecast_values(db_session, sites, "pvnet_india")

@pytest.fixture()
def forecast_values_wind(db_session, sites):
"""Create some fake forecast values"""

make_fake_forecast_values(db_session, sites, "windnet_india")

@pytest.fixture()
def forecast_values_site(db_session, sites):
"""Create some fake forecast values"""

make_fake_forecast_values(db_session, sites, "pvnet_ad_sites")


def make_fake_forecast_values(db_session, sites, model_name):
forecast_values = []
forecast_version: str = "0.0.0"

Expand All @@ -134,6 +152,9 @@ def forecast_values(db_session, sites):
# To make things trickier we make a second forecast at the same for one of the timestamps.
timestamps = timestamps + timestamps[-1:]

# get model
ml_model = get_or_create_model(db_session, model_name)

for site in sites:
for timestamp in timestamps:
forecast: ForecastSQL = ForecastSQL(
Expand All @@ -154,6 +175,7 @@ def forecast_values(db_session, sites):
end_utc=timestamp + timedelta(minutes=horizon + duration),
horizon_minutes=horizon,
)
forecast_value.ml_model = ml_model

forecast_values.append(forecast_value)

Expand Down
6 changes: 3 additions & 3 deletions src/india_api/internal/inputs/indiadb/test_indiadb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def client(engine, db_session):

class TestIndiaDBClient:
def test_get_predicted_wind_power_production_for_location(
self, client, forecast_values
self, client, forecast_values_wind
) -> None:
locID = "testID"
result = client.get_predicted_wind_power_production_for_location(locID)
Expand All @@ -33,7 +33,7 @@ def test_get_predicted_wind_power_production_for_location(
assert isinstance(record, PredictedPower)

def test_get_predicted_wind_power_production_for_location_raise_error(
self, client, forecast_values
self, client, forecast_values_wind
) -> None:

with pytest.raises(Exception):
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_get_sites_no_sites(self, client, sites) -> None:
sites_from_api = client.get_sites(email="test2@test.com")
assert len(sites_from_api) == 0

def test_get_site_forecast(self, client, sites, forecast_values) -> None:
def test_get_site_forecast(self, client, sites, forecast_values_site) -> None:
out = client.get_site_forecast(site_uuid=str(sites[0].site_uuid), email="test@test.com")
assert len(out) > 0

Expand Down

0 comments on commit 1bb40b8

Please sign in to comment.