From 6bf0dbfd34f1a0a91f4cbfb13f4a3898df34377a Mon Sep 17 00:00:00 2001 From: Jurkash Date: Fri, 19 Jul 2024 02:01:23 +0300 Subject: [PATCH] Rework calendar feature --- custom_components/loe_outages/__init__.py | 2 +- custom_components/loe_outages/api.py | 65 ++++++++++++++----- custom_components/loe_outages/calendar.py | 8 ++- custom_components/loe_outages/coordinator.py | 66 ++++++++------------ custom_components/loe_outages/manifest.json | 12 ++-- custom_components/loe_outages/models.py | 44 +++++++------ 6 files changed, 111 insertions(+), 86 deletions(-) diff --git a/custom_components/loe_outages/__init__.py b/custom_components/loe_outages/__init__.py index 0db8186..01af8d8 100644 --- a/custom_components/loe_outages/__init__.py +++ b/custom_components/loe_outages/__init__.py @@ -27,7 +27,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: entry.runtime_data = coordinator await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) - entry.async_on_unload(entry.add_update_listener(coordinator.update_config)) + entry.async_on_unload(entry.add_update_listener(coordinator.async_update_config)) return True diff --git a/custom_components/loe_outages/api.py b/custom_components/loe_outages/api.py index f2bd68c..3416d91 100644 --- a/custom_components/loe_outages/api.py +++ b/custom_components/loe_outages/api.py @@ -4,7 +4,7 @@ import aiohttp import datetime import pytz -from .models import OutageSchedule +from .models import OutageSchedule, Interval LOGGER = logging.getLogger(__name__) @@ -46,15 +46,15 @@ async def async_fetch_all_json(self) -> dict: async def async_fetch_schedules(self) -> None: """Fetch outages from the JSON response.""" if len(self.schedules) == 0: + LOGGER.debug("Fetching all schedules") schedules_data = await self.async_fetch_all_json() schedules = OutageSchedule.from_list(schedules_data) for schedule in sorted(schedules, key=lambda s: s.date): self.schedules.append(schedule) - return + return None else: + LOGGER.debug("Fetching latest schedules") schedule_data = await self.async_fetch_latest_json() - schedule = OutageSchedule.from_dict(schedule_data) - new_schedule = OutageSchedule.from_dict(schedule_data) self.schedules = [ item @@ -63,38 +63,71 @@ async def async_fetch_schedules(self) -> None: ] self.schedules.append(new_schedule) self.schedules.sort(key=lambda item: item.date) + LOGGER.debug("Saved schedules %s", list(map(lambda s: s.date, self.schedules))) + return None - def get_current_event(self, at: datetime) -> dict: + def get_current_event(self, at: datetime.datetime) -> Interval | None: """Get the current event.""" - if not self.schedules: + if not self.schedules or len(self.schedules) == 0: + LOGGER.debug("No schedules found") return None - twoDaysBefore = datetime.datetime.now() + datetime.timedelta(days=-2) + at = at.astimezone(pytz.UTC) + twoDaysBefore = (at + datetime.timedelta(days=-2)).astimezone(pytz.UTC) for schedule in reversed(self.schedules): - if schedule.date < twoDaysBefore.astimezone(pytz.UTC): + LOGGER.debug("Schedule to compare: %s < %s", schedule.date, twoDaysBefore) + if schedule.date < twoDaysBefore: return None events_at = schedule.get_current_event(self.group, at) - if not events_at: - return None - return events_at # return only the first event + if events_at is not None: + LOGGER.debug("Some event was found: %s", events_at) + return events_at # return only the first event + + LOGGER.debug("No evets at found") + return None def get_events( self, start_date: datetime.datetime, end_date: datetime.datetime, - ) -> list[dict]: + ) -> list[Interval]: """Get all events.""" - if not self.schedules: + if not self.schedules or len(self.schedules) == 0: return [] + start_date = start_date.astimezone(pytz.UTC) + end_date = end_date.astimezone(pytz.UTC) result = [] - twoDaysBeforeStart = start_date + datetime.timedelta(days=-2) + twoDaysBeforeStart = (start_date + datetime.timedelta(days=-2)).astimezone( + pytz.UTC + ) for schedule in reversed(self.schedules): if schedule.date < twoDaysBeforeStart: break - for interval in schedule.between(self.group, start_date, end_date): + for interval in schedule.intersect(self.group, start_date, end_date): result.append(interval) - return result + return self._merge_intervals(sorted(result, key=lambda i: i.startTime)) + + def _merge_intervals(self, intervals: list[Interval]) -> list[Interval]: + if not intervals: + return [] + + # Start with the first interval + merged_intervals = [intervals[0]] + + for current in intervals[1:]: + last = merged_intervals[-1] + if last.endTime == current.startTime and last.state == current.state: + merged_intervals[-1] = Interval( + startTime=last.startTime, endTime=current.endTime, state=last.state + ) + else: + merged_intervals.append(current) + [ + LOGGER.debug("merged: from: %s, to: %s", inter.startTime, inter.endTime) + for inter in merged_intervals + ] + return merged_intervals diff --git a/custom_components/loe_outages/calendar.py b/custom_components/loe_outages/calendar.py index fad6849..8b2acec 100644 --- a/custom_components/loe_outages/calendar.py +++ b/custom_components/loe_outages/calendar.py @@ -2,13 +2,13 @@ import datetime import logging +import pytz from homeassistant.components.calendar import CalendarEntity, CalendarEvent from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.entity import EntityDescription from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.util import dt as dt_utils from .coordinator import LoeOutagesCoordinator from .entity import LoeOutagesEntity @@ -50,9 +50,11 @@ def __init__( @property def event(self) -> CalendarEvent | None: """Return the current or next upcoming event or None.""" - now = dt_utils.now() + utc = pytz.UTC + now = datetime.datetime.now().astimezone(utc) LOGGER.debug("Getting current event for %s", now) - return self.coordinator.get_calendar_at(now) + res = self.coordinator.get_calendar_at(now) + return res async def async_get_events( self, diff --git a/custom_components/loe_outages/coordinator.py b/custom_components/loe_outages/coordinator.py index dc823d9..27a195a 100644 --- a/custom_components/loe_outages/coordinator.py +++ b/custom_components/loe_outages/coordinator.py @@ -58,7 +58,7 @@ def event_name_map(self) -> dict: STATE_ON: self.translations.get(TRANSLATION_KEY_EVENT_ON), } - async def update_config( + async def async_update_config( self, hass: HomeAssistant, # noqa: ARG002 config_entry: ConfigEntry, @@ -96,12 +96,12 @@ async def async_fetch_translations(self) -> None: def _get_next_event_of_type(self, state_type: str) -> Interval | None: """Get the next event of a specific type.""" - now = dt_utils.now() + now = dt_utils.now().astimezone(pytz.UTC) # Sort events to handle multi-day spanning events correctly next_events = sorted( self.get_intervals_between( now, - now + TIMEFRAME_TO_CHECK, + (now + TIMEFRAME_TO_CHECK).astimezone(pytz.UTC), translate=False, ), key=lambda event: event.startTime, @@ -122,13 +122,6 @@ def next_outage(self) -> datetime.datetime | None: @property def next_connectivity(self) -> datetime.datetime | None: """Get next connectivity time.""" - now = dt_utils.now() - current_event = self.get_interval_at(now) - # If current event is OFF, return the end time - if self._event_to_state(current_event) == STATE_OFF: - return current_event.endTime - - # Otherwise, return the next on event's end event = self._get_next_event_of_type(STATE_ON) LOGGER.debug("Next connectivity: %s", event) return event.startTime if event else None @@ -136,11 +129,11 @@ def next_connectivity(self) -> datetime.datetime | None: @property def current_state(self) -> str: """Get the current state.""" - now = dt_utils.now() + now = dt_utils.now().astimezone(pytz.UTC) event = self.get_interval_at(now) return self._event_to_state(event) - def get_interval_at(self, at: datetime.datetime) -> Interval: + def get_interval_at(self, at: datetime.datetime) -> Interval | None: """Get the current event.""" event = self.api.get_current_event(at) return self._get_interval_event(event, translate=False) @@ -160,33 +153,31 @@ def get_intervals_between( def _get_interval_event( self, - event: dict | None, + interval: Interval | None, *, translate: bool = True, ) -> Interval: """Transform an event into a Inteval.""" - if not event: + if not interval: return None - event_summary = event["state"] - translated_summary = self.event_name_map.get(event_summary) - event_start = event["startTime"] - event_end = event["endTime"] + interval_summary = interval.state + translated_summary = self.event_name_map.get(interval_summary) LOGGER.debug( "Transforming event: %s (%s -> %s)", - event_summary, - event_start, - event_end, + interval_summary, + interval.startTime, + interval.endTime, ) return Interval( - state=translated_summary if translate else event_summary, - startTime=event_start, - endTime=event_end, + state=translated_summary if translate else interval_summary, + startTime=interval.startTime, + endTime=interval.endTime, ) - def get_calendar_at(self, at: datetime.datetime) -> CalendarEvent: + def get_calendar_at(self, at: datetime.datetime) -> CalendarEvent | None: """Get the current event.""" event = self.api.get_current_event(at) return self._get_calendar_event(event, translate=False) @@ -206,32 +197,29 @@ def get_calendar_between( def _get_calendar_event( self, - event: dict | None, + interval: Interval | None, *, translate: bool = True, ) -> CalendarEvent: """Transform an event into a Inteval.""" - if not event: + if not interval: return None - local_tz = pytz.timezone("Europe/Kyiv") - event_summary = event["state"] - translated_summary = self.event_name_map.get(event_summary) - event_start = event["startTime"].astimezone(local_tz) - event_end = event["endTime"].astimezone(local_tz) + interval_summary = interval.state + translated_summary = self.event_name_map.get(interval_summary) LOGGER.debug( "Transforming event: %s (%s -> %s)", - event_summary, - event_start, - event_end, + interval_summary, + interval.startTime, + interval.endTime, ) return CalendarEvent( - summary=translated_summary if translate else event_summary, - start=event_start, - end=event_end, - description=event_summary, + summary=translated_summary if translate else interval_summary, + start=interval.startTime, + end=interval.endTime, + description=interval_summary, ) def _event_to_state(self, event: Interval | None) -> str: diff --git a/custom_components/loe_outages/manifest.json b/custom_components/loe_outages/manifest.json index 84f520c..20633a4 100644 --- a/custom_components/loe_outages/manifest.json +++ b/custom_components/loe_outages/manifest.json @@ -1,15 +1,11 @@ { "domain": "loe_outages", "name": "LOE Outages", - "codeowners": [ - "@jurkash" - ], + "codeowners": ["@jurkash"], "config_flow": true, "documentation": "https://github.com/jurkash/ha-loe-outages", "iot_class": "calculated", "issue_tracker": "https://github.com/jurkash/ha-loe-outages", - "requirements": [ - - ], - "version": "0.0.3" -} \ No newline at end of file + "requirements": [], + "version": "0.1.0" +} diff --git a/custom_components/loe_outages/models.py b/custom_components/loe_outages/models.py index fb97659..f7d2b4b 100644 --- a/custom_components/loe_outages/models.py +++ b/custom_components/loe_outages/models.py @@ -1,13 +1,15 @@ import datetime import pytz -from dateutil import parser -from typing import List +import logging utc = pytz.UTC +LOGGER = logging.getLogger(__name__) class Interval: - def __init__(self, state: str, startTime: datetime, endTime: datetime): + def __init__( + self, state: str, startTime: datetime.datetime, endTime: datetime.datetime + ): self.state = state self.startTime = startTime self.endTime = endTime @@ -16,8 +18,10 @@ def __init__(self, state: str, startTime: datetime, endTime: datetime): def from_dict(obj: dict) -> "Interval": return Interval( state=obj.get("state").lower(), - startTime=parser.parse(obj.get("startTime")).astimezone(utc), - endTime=parser.parse(obj.get("endTime")).astimezone(utc), + startTime=datetime.datetime.fromisoformat(obj.get("startTime")).astimezone( + utc + ), + endTime=datetime.datetime.fromisoformat(obj.get("endTime")).astimezone(utc), ) def to_dict(self) -> dict: @@ -29,7 +33,7 @@ def to_dict(self) -> dict: class Group: - def __init__(self, id: str, intervals: List[Interval]): + def __init__(self, id: str, intervals: list[Interval]): self.id = id self.intervals = intervals @@ -51,10 +55,10 @@ class OutageSchedule: def __init__( self, id: str, - date: datetime, + date: datetime.datetime, dateString: str, imageUrl: str, - groups: List[Group], + groups: list[Group], ): self.id = id self.date = date @@ -63,15 +67,15 @@ def __init__( self.groups = groups @staticmethod - def from_list(obj_list: List[dict]) -> List["OutageSchedule"]: + def from_list(obj_list: list[dict]) -> list["OutageSchedule"]: return [OutageSchedule.from_dict(item) for item in obj_list] @staticmethod - def from_dict(obj: dict) -> list["OutageSchedule"]: + def from_dict(obj: dict) -> "OutageSchedule": groups = [Group.from_dict(group) for group in obj.get("groups", [])] return OutageSchedule( id=obj.get("id"), - date=parser.parse(obj.get("date")).astimezone(utc), + date=datetime.datetime.fromisoformat(obj.get("date")).astimezone(utc), dateString=obj.get("dateString"), imageUrl=obj.get("imageUrl"), groups=groups, @@ -86,24 +90,26 @@ def to_dict(self) -> dict: "groups": [group.to_dict() for group in self.groups], } - def get_current_event(self, group_id: str, at: datetime.datetime) -> dict: + def get_current_event( + self, group_id: str, at: datetime.datetime + ) -> Interval | None: at = at.astimezone(utc) for group in self.groups: if group.id == group_id: for interval in group.intervals: - if interval.startTime <= at <= interval.endTime: - return interval.to_dict() - return {} + if interval.startTime <= at and at <= interval.endTime: + return interval + return None - def between( + def intersect( self, group_id: str, start: datetime.datetime, end: datetime.datetime - ) -> list[dict]: + ) -> list[Interval]: start = start.astimezone(utc) end = end.astimezone(utc) res = [] for group in self.groups: if group.id == group_id: for interval in group.intervals: - if start <= interval.startTime and interval.endTime <= end: - res.append(interval.to_dict()) + if interval.startTime <= end and start <= interval.endTime: + res.append(interval) return res