Skip to content

Commit 10aee19

Browse files
committed
feat: Support fetching of multiple accessories at once
1 parent f01ad6b commit 10aee19

File tree

2 files changed

+175
-44
lines changed

2 files changed

+175
-44
lines changed

findmy/reports/account.py

Lines changed: 102 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
TYPE_CHECKING,
1616
Any,
1717
Callable,
18-
Sequence,
1918
TypedDict,
2019
TypeVar,
2120
cast,
@@ -49,6 +48,8 @@
4948
)
5049

5150
if TYPE_CHECKING:
51+
from collections.abc import Sequence
52+
5253
from findmy.accessory import RollingKeyPairSource
5354
from findmy.keys import HasHashedPublicKey
5455
from findmy.util.types import MaybeCoro
@@ -248,13 +249,28 @@ def fetch_reports(
248249
date_to: datetime | None,
249250
) -> MaybeCoro[list[LocationReport]]: ...
250251

252+
@overload
253+
def fetch_reports(
254+
self,
255+
keys: Sequence[RollingKeyPairSource],
256+
date_from: datetime,
257+
date_to: datetime | None,
258+
) -> MaybeCoro[dict[RollingKeyPairSource, list[LocationReport]]]: ...
259+
251260
@abstractmethod
252261
def fetch_reports(
253262
self,
254-
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
263+
keys: HasHashedPublicKey
264+
| Sequence[HasHashedPublicKey]
265+
| RollingKeyPairSource
266+
| Sequence[RollingKeyPairSource],
255267
date_from: datetime,
256268
date_to: datetime | None,
257-
) -> MaybeCoro[list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]]:
269+
) -> MaybeCoro[
270+
list[LocationReport]
271+
| dict[HasHashedPublicKey, list[LocationReport]]
272+
| dict[RollingKeyPairSource, list[LocationReport]]
273+
]:
258274
"""
259275
Fetch location reports for `HasHashedPublicKey`s between `date_from` and `date_end`.
260276
@@ -286,12 +302,27 @@ def fetch_last_reports(
286302
hours: int = 7 * 24,
287303
) -> MaybeCoro[list[LocationReport]]: ...
288304

305+
@overload
306+
@abstractmethod
307+
def fetch_last_reports(
308+
self,
309+
keys: Sequence[RollingKeyPairSource],
310+
hours: int = 7 * 24,
311+
) -> MaybeCoro[dict[RollingKeyPairSource, list[LocationReport]]]: ...
312+
289313
@abstractmethod
290314
def fetch_last_reports(
291315
self,
292-
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
316+
keys: HasHashedPublicKey
317+
| Sequence[HasHashedPublicKey]
318+
| RollingKeyPairSource
319+
| Sequence[RollingKeyPairSource],
293320
hours: int = 7 * 24,
294-
) -> MaybeCoro[list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]]:
321+
) -> MaybeCoro[
322+
list[LocationReport]
323+
| dict[HasHashedPublicKey, list[LocationReport]]
324+
| dict[RollingKeyPairSource, list[LocationReport]]
325+
]:
295326
"""
296327
Fetch location reports for a sequence of `HasHashedPublicKey`s for the last `hours` hours.
297328
@@ -641,14 +672,29 @@ async def fetch_reports(
641672
date_to: datetime | None,
642673
) -> list[LocationReport]: ...
643674

675+
@overload
676+
async def fetch_reports(
677+
self,
678+
keys: Sequence[RollingKeyPairSource],
679+
date_from: datetime,
680+
date_to: datetime | None,
681+
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
682+
644683
@require_login_state(LoginState.LOGGED_IN)
645684
@override
646685
async def fetch_reports(
647686
self,
648-
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
687+
keys: HasHashedPublicKey
688+
| Sequence[HasHashedPublicKey]
689+
| RollingKeyPairSource
690+
| Sequence[RollingKeyPairSource],
649691
date_from: datetime,
650692
date_to: datetime | None,
651-
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
693+
) -> (
694+
list[LocationReport]
695+
| dict[HasHashedPublicKey, list[LocationReport]]
696+
| dict[RollingKeyPairSource, list[LocationReport]]
697+
):
652698
"""See `BaseAppleAccount.fetch_reports`."""
653699
date_to = date_to or datetime.now().astimezone()
654700

@@ -679,13 +725,27 @@ async def fetch_last_reports(
679725
hours: int = 7 * 24,
680726
) -> list[LocationReport]: ...
681727

728+
@overload
729+
async def fetch_last_reports(
730+
self,
731+
keys: Sequence[RollingKeyPairSource],
732+
hours: int = 7 * 24,
733+
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
734+
682735
@require_login_state(LoginState.LOGGED_IN)
683736
@override
684737
async def fetch_last_reports(
685738
self,
686-
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
739+
keys: HasHashedPublicKey
740+
| Sequence[HasHashedPublicKey]
741+
| RollingKeyPairSource
742+
| Sequence[RollingKeyPairSource],
687743
hours: int = 7 * 24,
688-
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
744+
) -> (
745+
list[LocationReport]
746+
| dict[HasHashedPublicKey, list[LocationReport]]
747+
| dict[RollingKeyPairSource, list[LocationReport]]
748+
):
689749
"""See `BaseAppleAccount.fetch_last_reports`."""
690750
end = datetime.now(tz=timezone.utc)
691751
start = end - timedelta(hours=hours)
@@ -1041,13 +1101,28 @@ def fetch_reports(
10411101
date_to: datetime | None,
10421102
) -> list[LocationReport]: ...
10431103

1104+
@overload
1105+
def fetch_reports(
1106+
self,
1107+
keys: Sequence[RollingKeyPairSource],
1108+
date_from: datetime,
1109+
date_to: datetime | None,
1110+
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
1111+
10441112
@override
10451113
def fetch_reports(
10461114
self,
1047-
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
1115+
keys: HasHashedPublicKey
1116+
| Sequence[HasHashedPublicKey]
1117+
| RollingKeyPairSource
1118+
| Sequence[RollingKeyPairSource],
10481119
date_from: datetime,
10491120
date_to: datetime | None,
1050-
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
1121+
) -> (
1122+
list[LocationReport]
1123+
| dict[HasHashedPublicKey, list[LocationReport]]
1124+
| dict[RollingKeyPairSource, list[LocationReport]]
1125+
):
10511126
"""See `AsyncAppleAccount.fetch_reports`."""
10521127
coro = self._asyncacc.fetch_reports(keys, date_from, date_to)
10531128
return self._evt_loop.run_until_complete(coro)
@@ -1073,12 +1148,26 @@ def fetch_last_reports(
10731148
hours: int = 7 * 24,
10741149
) -> list[LocationReport]: ...
10751150

1151+
@overload
1152+
def fetch_last_reports(
1153+
self,
1154+
keys: Sequence[RollingKeyPairSource],
1155+
hours: int = 7 * 24,
1156+
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
1157+
10761158
@override
10771159
def fetch_last_reports(
10781160
self,
1079-
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
1161+
keys: HasHashedPublicKey
1162+
| Sequence[HasHashedPublicKey]
1163+
| RollingKeyPairSource
1164+
| Sequence[RollingKeyPairSource],
10801165
hours: int = 7 * 24,
1081-
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
1166+
) -> (
1167+
list[LocationReport]
1168+
| dict[HasHashedPublicKey, list[LocationReport]]
1169+
| dict[RollingKeyPairSource, list[LocationReport]]
1170+
):
10821171
"""See `AsyncAppleAccount.fetch_last_reports`."""
10831172
coro = self._asyncacc.fetch_last_reports(keys, hours)
10841173
return self._evt_loop.run_until_complete(coro)

findmy/reports/reports.py

Lines changed: 73 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
import hashlib
77
import logging
88
import struct
9+
from collections import defaultdict
910
from datetime import datetime, timedelta, timezone
10-
from typing import TYPE_CHECKING, overload
11+
from typing import TYPE_CHECKING, cast, overload
1112

1213
from cryptography.hazmat.backends import default_backend
1314
from cryptography.hazmat.primitives.asymmetric import ec
@@ -260,12 +261,27 @@ async def fetch_reports(
260261
device: RollingKeyPairSource,
261262
) -> list[LocationReport]: ...
262263

264+
@overload
265+
async def fetch_reports(
266+
self,
267+
date_from: datetime,
268+
date_to: datetime,
269+
device: Sequence[RollingKeyPairSource],
270+
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...
271+
263272
async def fetch_reports(
264273
self,
265274
date_from: datetime,
266275
date_to: datetime,
267-
device: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
268-
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
276+
device: HasHashedPublicKey
277+
| Sequence[HasHashedPublicKey]
278+
| RollingKeyPairSource
279+
| Sequence[RollingKeyPairSource],
280+
) -> (
281+
list[LocationReport]
282+
| dict[HasHashedPublicKey, list[LocationReport]]
283+
| dict[RollingKeyPairSource, list[LocationReport]]
284+
):
269285
"""
270286
Fetch location reports for a certain device.
271287
@@ -276,45 +292,71 @@ async def fetch_reports(
276292
When ``device`` is a :class:`.RollingKeyPairSource`, it will return a list of
277293
location reports corresponding to that source.
278294
"""
279-
# single key
295+
key_devs: (
296+
dict[HasHashedPublicKey, HasHashedPublicKey]
297+
| dict[HasHashedPublicKey, RollingKeyPairSource]
298+
) = {}
280299
if isinstance(device, HasHashedPublicKey):
281-
return await self._fetch_reports(date_from, date_to, [device])
282-
283-
# key generator
284-
# add 12h margin to the generator
285-
if isinstance(device, RollingKeyPairSource):
286-
keys = list(
287-
device.keys_between(
300+
# single key
301+
key_devs = {device: device}
302+
elif isinstance(device, list) and all(isinstance(x, HasHashedPublicKey) for x in device):
303+
# multiple static keys
304+
device = cast(list[HasHashedPublicKey], device)
305+
key_devs = {key: key for key in device}
306+
elif isinstance(device, RollingKeyPairSource):
307+
# key generator
308+
# add 12h margin to the generator
309+
key_devs = {
310+
key: device
311+
for key in device.keys_between(
312+
date_from - timedelta(hours=12),
313+
date_to + timedelta(hours=12),
314+
)
315+
}
316+
elif isinstance(device, list) and all(isinstance(x, RollingKeyPairSource) for x in device):
317+
# multiple key generators
318+
# add 12h margin to each generator
319+
device = cast(list[RollingKeyPairSource], device)
320+
key_devs = {
321+
key: dev
322+
for dev in device
323+
for key in dev.keys_between(
288324
date_from - timedelta(hours=12),
289325
date_to + timedelta(hours=12),
290-
),
291-
)
326+
)
327+
}
292328
else:
293-
keys = device
329+
msg = "Unknown device type: %s"
330+
raise ValueError(msg, type(device))
294331

295332
# sequence of keys (fetch 256 max at a time)
296-
reports: list[LocationReport] = []
333+
key_reports: dict[HasHashedPublicKey, list[LocationReport]] = {}
334+
keys = list(key_devs.keys())
297335
for key_offset in range(0, len(keys), 256):
298-
chunk = keys[key_offset : key_offset + 256]
299-
reports.extend(await self._fetch_reports(date_from, date_to, chunk))
300-
301-
if isinstance(device, RollingKeyPairSource):
302-
return reports
303-
304-
res: dict[HasHashedPublicKey, list[LocationReport]] = {key: [] for key in keys}
305-
for report in reports:
306-
for key in res:
307-
if key.hashed_adv_key_bytes == report.hashed_adv_key_bytes:
308-
res[key].append(report)
309-
break
310-
return res
336+
chunk_keys = keys[key_offset : key_offset + 256]
337+
chunk_reports = await self._fetch_reports(date_from, date_to, chunk_keys)
338+
key_reports |= chunk_reports
339+
340+
# combine (key -> list[report]) and (key -> device) into (device -> list[report])
341+
device_reports = defaultdict(list)
342+
for key, reports in key_reports.items():
343+
device_reports[key_devs[key]].extend(reports)
344+
for dev in device_reports:
345+
device_reports[dev] = sorted(device_reports[dev])
346+
347+
# result
348+
if isinstance(device, (HasHashedPublicKey, RollingKeyPairSource)):
349+
# single key or generator
350+
return device_reports[device]
351+
# multiple static keys or key generators
352+
return device_reports
311353

312354
async def _fetch_reports(
313355
self,
314356
date_from: datetime,
315357
date_to: datetime,
316358
keys: Sequence[HasHashedPublicKey],
317-
) -> list[LocationReport]:
359+
) -> dict[HasHashedPublicKey, list[LocationReport]]:
318360
logging.debug("Fetching reports for %s keys", len(keys))
319361

320362
# lock requested time range to the past 7 days, +- 12 hours, then filter the response.
@@ -327,7 +369,7 @@ async def _fetch_reports(
327369
data = await self._account.fetch_raw_reports(start_date, end_date, ids)
328370

329371
id_to_key: dict[bytes, HasHashedPublicKey] = {key.hashed_adv_key_bytes: key for key in keys}
330-
reports: list[LocationReport] = []
372+
reports: dict[HasHashedPublicKey, list[LocationReport]] = defaultdict(list)
331373
for report in data.get("results", []):
332374
payload = base64.b64decode(report["payload"])
333375
hashed_adv_key = base64.b64decode(report["id"])
@@ -347,6 +389,6 @@ async def _fetch_reports(
347389
if loc_report.timestamp < date_from or loc_report.timestamp > date_to:
348390
continue
349391

350-
reports.append(loc_report)
392+
reports[key].append(loc_report)
351393

352394
return reports

0 commit comments

Comments
 (0)