Skip to content

Commit

Permalink
Add SAS token renewal
Browse files Browse the repository at this point in the history
  • Loading branch information
KaSroka committed Sep 9, 2024
1 parent 3f88f91 commit bf1e8c5
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 36 deletions.
37 changes: 7 additions & 30 deletions toshiba_ac/device/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
ToshibaAcStatus,
ToshibaAcSwingMode,
)
from toshiba_ac.utils import async_sleep_until_next_multiply_of_minutes, pretty_enum_name
from toshiba_ac.utils import async_sleep_until_next_multiply_of_minutes, pretty_enum_name, ToshibaAcCallback
from toshiba_ac.utils.amqp_api import ToshibaAcAmqpApi, JSONSerializable
from toshiba_ac.utils.http_api import ToshibaAcHttpApi

Expand All @@ -44,33 +44,8 @@ class ToshibaAcDeviceError(Exception):
pass


class ToshibaAcDeviceCallback:
def __init__(self) -> None:
self.callbacks: t.List[t.Callable[[ToshibaAcDevice], t.Optional[t.Awaitable[None]]]] = []

def add(self, callback: t.Callable[[ToshibaAcDevice], t.Optional[t.Awaitable[None]]]) -> bool:
if callback not in self.callbacks:
self.callbacks.append(callback)
return True

return False

def remove(self, callback: t.Callable[[ToshibaAcDevice], t.Optional[t.Awaitable[None]]]) -> bool:
if callback in self.callbacks:
self.callbacks.remove(callback)
return True

return False

async def __call__(self, dev: ToshibaAcDevice) -> None:
for callback in self.callbacks:
asyncs = []
if asyncio.iscoroutinefunction(callback):
asyncs.append(t.cast(t.Awaitable[None], callback(dev)))
else:
callback(dev)

await asyncio.gather(*asyncs)
class ToshibaAcDeviceCallback(ToshibaAcCallback["ToshibaAcDevice"]):
pass


class ToshibaAcDevice:
Expand Down Expand Up @@ -105,6 +80,7 @@ def __init__(
self._on_state_changed_callback = ToshibaAcDeviceCallback()
self._on_energy_consumption_changed_callback = ToshibaAcDeviceCallback()
self._ac_energy_consumption: t.Optional[ToshibaAcDeviceEnergyConsumption] = None
self.periodic_reload_state_task: t.Optional[asyncio.Task[None]] = None

logger.debug(f"[{self.name}] {self.supported}")

Expand All @@ -113,8 +89,9 @@ async def connect(self) -> None:
self.periodic_reload_state_task = asyncio.get_running_loop().create_task(self.periodic_state_reload())

async def shutdown(self) -> None:
self.periodic_reload_state_task.cancel()
await self.periodic_reload_state_task
if self.periodic_reload_state_task:
self.periodic_reload_state_task.cancel()
await self.periodic_reload_state_task

async def load_additional_device_info(self) -> None:
additional_info = await self.http_api.get_device_additional_info(self.ac_id)
Expand Down
21 changes: 19 additions & 2 deletions toshiba_ac/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import typing as t

from toshiba_ac.device import ToshibaAcDevice
from toshiba_ac.utils import async_sleep_until_next_multiply_of_minutes
from toshiba_ac.utils import async_sleep_until_next_multiply_of_minutes, ToshibaAcCallback
from toshiba_ac.utils.amqp_api import ToshibaAcAmqpApi, JSONSerializable
from toshiba_ac.utils.http_api import ToshibaAcHttpApi

Expand All @@ -28,6 +28,10 @@ class ToshibaAcDeviceManagerError(Exception):
pass


class ToshibaAcSasTokenUpdatedCallback(ToshibaAcCallback[str]):
pass


class ToshibaAcDeviceManager:
FETCH_ENERGY_CONSUMPTION_PERIOD_MINUTES = 10

Expand All @@ -49,6 +53,7 @@ def __init__(
self.periodic_fetch_energy_consumption_task: t.Optional[asyncio.Task[None]] = None
self.lock = asyncio.Lock()
self.loop = asyncio.get_running_loop()
self._on_sas_token_updated_callback = ToshibaAcSasTokenUpdatedCallback()

async def connect(self) -> str:
try:
Expand All @@ -61,7 +66,7 @@ async def connect(self) -> str:
self.sas_token = await self.http_api.register_client(self.device_id)

if not self.amqp_api:
self.amqp_api = ToshibaAcAmqpApi(self.sas_token)
self.amqp_api = ToshibaAcAmqpApi(self.sas_token, self.renew_sas_token)
self.amqp_api.register_command_handler("CMD_FCU_FROM_AC", self.handle_cmd_fcu_from_ac)
self.amqp_api.register_command_handler("CMD_HEARTBEAT", self.handle_cmd_heartbeat)
await self.amqp_api.connect()
Expand Down Expand Up @@ -195,6 +200,14 @@ async def get_devices(self) -> t.List[ToshibaAcDevice]:

return list(self.devices.values())

async def renew_sas_token(self) -> str:
if self.http_api:
self.sas_token = await self.http_api.register_client(self.device_id)
await self.on_sas_token_updated_callback(self.sas_token)
return self.sas_token

raise ToshibaAcDeviceManagerError("Not connected")

def handle_cmd_fcu_from_ac(
self,
source_id: str,
Expand All @@ -214,3 +227,7 @@ def handle_cmd_heartbeat(
timestamp: str,
) -> None:
asyncio.run_coroutine_threadsafe(self.devices[source_id].handle_cmd_heartbeat(payload), self.loop).result()

@property
def on_sas_token_updated_callback(self) -> ToshibaAcSasTokenUpdatedCallback:
return self._on_sas_token_updated_callback
33 changes: 33 additions & 0 deletions toshiba_ac/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,36 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return wrapper

return decorator


T = t.TypeVar("T") # Generic type variable for devices


class ToshibaAcCallback(t.Generic[T]):
def __init__(self) -> None:
self.callbacks: t.List[t.Callable[[T], t.Optional[t.Awaitable[None]]]] = []

def add(self, callback: t.Callable[[T], t.Optional[t.Awaitable[None]]]) -> bool:
if callback not in self.callbacks:
self.callbacks.append(callback)
return True

return False

def remove(self, callback: t.Callable[[T], t.Optional[t.Awaitable[None]]]) -> bool:
if callback in self.callbacks:
self.callbacks.remove(callback)
return True

return False

async def __call__(self, device: T) -> None:
asyncs = []

for callback in self.callbacks:
if asyncio.iscoroutinefunction(callback):
asyncs.append(t.cast(t.Awaitable[None], callback(device)))
else:
callback(device)

await asyncio.gather(*asyncs)
9 changes: 8 additions & 1 deletion toshiba_ac/utils/amqp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ class ToshibaAcAmqpApi:
COMMANDS = ["CMD_FCU_FROM_AC", "CMD_HEARTBEAT"]
_HANDLER_TYPE = t.Callable[[str, str, list[JSONSerializable], dict[str, JSONSerializable], str], None]

def __init__(self, sas_token: str) -> None:
def __init__(self, sas_token: str, new_sas_token_required_callback: t.Callable[[], t.Awaitable[str]]) -> None:
self.sas_token = sas_token
self.handlers: t.Dict[str, ToshibaAcAmqpApi._HANDLER_TYPE] = {}

self.device = IoTHubDeviceClient.create_from_sastoken(self.sas_token)
self.device.on_method_request_received = self.method_request_received
self.device.on_new_sastoken_required = self.new_sas_token_required # type: ignore
self.on_new_sastoken_required_callback = new_sas_token_required_callback

async def connect(self) -> None:
await self.device.connect()
Expand All @@ -46,6 +48,11 @@ def register_command_handler(self, command: str, handler: ToshibaAcAmqpApi._HAND
raise AttributeError(f'Unknown command: {command}, should be one of {" ".join(self.COMMANDS)}')
self.handlers[command] = handler

async def new_sas_token_required(self) -> None:
logger.info(f"SAS token is about to expire")
new_token = await self.on_new_sastoken_required_callback()
await self.device.update_sastoken(new_token)

async def method_request_received(self, method_data: MethodRequest) -> None:
if method_data.name != "smmobile":
return logger.info(f"Unknown method name: {method_data.name} full data: {method_data.payload}")
Expand Down
6 changes: 3 additions & 3 deletions toshiba_ac/utils/http_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def __init__(self, username: str, password: str) -> None:
self.consumer_id: t.Optional[str] = None
self.session: t.Optional[aiohttp.ClientSession] = None

@retry_with_timeout(timeout=5, retries=3, backoff=10)
@retry_on_exception(exceptions=ToshibaAcHttpApiError, retries=3, backoff=10)
@retry_with_timeout(timeout=5, retries=3, backoff=60)
@retry_on_exception(exceptions=ToshibaAcHttpApiError, retries=3, backoff=60)
async def request_api(
self,
path: str,
Expand Down Expand Up @@ -107,7 +107,7 @@ async def request_api(
return json["ResObj"]
else:
if json["StatusCode"] == "InvalidUserNameorPassword":
err_type = ToshibaAcHttpApiAuthError(json["Message"])
raise ToshibaAcHttpApiAuthError(json["Message"])

raise ToshibaAcHttpApiError(json["Message"])

Expand Down

0 comments on commit bf1e8c5

Please sign in to comment.