diff --git a/src/ferry_planner/route.py b/src/ferry_planner/route.py index 5637c56..a0c278a 100644 --- a/src/ferry_planner/route.py +++ b/src/ferry_planner/route.py @@ -253,32 +253,36 @@ class RoutePlanBuilder: def __init__(self, connection_db: ConnectionDB, /) -> None: self._connection_db = connection_db - def make_route_plans( + async def make_route_plans( self, *, routes: Iterable[Route], options: RoutePlansOptions, schedule_getter: ScheduleGetter, - ) -> Iterator[RoutePlan]: + ) -> Sequence[RoutePlan]: + route_plans = [] for route in routes: - yield from self._add_plan_segment( + await self._add_plan_segment( + route_plans=route_plans, route=route, destination_index=1, start_time=options.date.replace(hour=0, minute=0, second=0, microsecond=0), options=options, schedule_getter=schedule_getter, ) + return route_plans - def _add_plan_segment( # noqa: PLR0913 + async def _add_plan_segment( # noqa: PLR0913 self, *, + route_plans: list, route: Route, destination_index: int, start_time: datetime, options: RoutePlansOptions, schedule_getter: ScheduleGetter, segments: list[RoutePlanSegment] | None = None, - ) -> Generator[RoutePlan, None, bool]: + ) -> bool: if segments is None: segments = [] result = False @@ -286,13 +290,14 @@ def _add_plan_segment( # noqa: PLR0913 if destination_index == len(route): if not segments: # empty list? return False # can we be here? - yield RoutePlan.from_segments(segments) + route_plans.append(RoutePlan.from_segments(segments)) return True origin = route[destination_index - 1] destination = route[destination_index] connection = self._connection_db.from_to_location(origin, destination) if isinstance(connection, FerryConnection): - result = yield from self._add_ferry_connection( + result = await self._add_ferry_connection( + route_plans=route_plans, route=route, destination_index=destination_index, segments=segments, @@ -318,7 +323,8 @@ def _add_plan_segment( # noqa: PLR0913 ), ) segments.append(RoutePlanSegment(connection=connection, times=times)) - result = yield from self._add_plan_segment( + result = await self._add_plan_segment( + route_plans=route_plans, route=route, destination_index=destination_index + 1, segments=segments, @@ -331,9 +337,10 @@ def _add_plan_segment( # noqa: PLR0913 del segments[delete_start:] return result - def _add_ferry_connection( # noqa: C901, PLR0912, PLR0913 + async def _add_ferry_connection( # noqa: C901, PLR0912, PLR0913 self, *, + route_plans: list, route: Route, destination_index: int, segments: list[RoutePlanSegment], @@ -341,11 +348,11 @@ def _add_ferry_connection( # noqa: C901, PLR0912, PLR0913 options: RoutePlansOptions, connection: FerryConnection, schedule_getter: ScheduleGetter, - ) -> Generator[RoutePlan, None, bool]: + ) -> bool: result = False depature_terminal = connection.origin day = start_time.replace(hour=0, minute=0, second=0, microsecond=0) - schedule = schedule_getter(connection.origin.id, connection.destination.id, date=day) + schedule = await schedule_getter(connection.origin.id, connection.destination.id, date=day) if not schedule: return False for sailing in schedule.sailings: @@ -406,7 +413,8 @@ def _add_ferry_connection( # noqa: C901, PLR0912, PLR0913 schedule_url=schedule.url, ), ) - recursion_result = yield from self._add_plan_segment( + recursion_result = await self._add_plan_segment( + route_plans=route_plans, route=route, destination_index=destination_index + 1, segments=segments, diff --git a/src/ferry_planner/schedule.py b/src/ferry_planner/schedule.py index 8f3ca16..c9d048c 100644 --- a/src/ferry_planner/schedule.py +++ b/src/ferry_planner/schedule.py @@ -71,7 +71,7 @@ class FerrySchedule(BaseModel): class ScheduleGetter(Protocol): - def __call__( + async def __call__( self, origin_id: LocationId, destination_id: LocationId, @@ -149,7 +149,7 @@ def _get_filepath( ) -> Path: return self.cache_dir / f"{origin_id}-{destination_id}" / f"{date.date()}.json" - def get( + async def get( self, origin_id: LocationId, destination_id: LocationId, @@ -165,7 +165,7 @@ def get( schedule = FerrySchedule.model_validate_json(filepath.read_text(encoding="utf-8")) self._mem_cache[filepath] = schedule return schedule - schedule = self.download_schedule(origin_id, destination_id, date=date) + schedule = await self.download_schedule(origin_id, destination_id, date=date) if schedule: self.put(schedule) return schedule @@ -182,22 +182,7 @@ def put(self, schedule: FerrySchedule, /) -> None: dirpath.mkdir(mode=0o755, parents=True, exist_ok=True) filepath.write_text(schedule.model_dump_json(indent=4, exclude_none=True), encoding="utf-8") - def download_schedule( - self, - origin_id: LocationId, - destination_id: LocationId, - /, - *, - date: datetime, - ) -> FerrySchedule | None: - coro = self.download_schedule_async(origin_id, destination_id, date=date) - try: - loop = asyncio.get_running_loop() - return loop.run_until_complete(coro) - except RuntimeError: - return asyncio.run(coro) - - async def download_schedule_async( + async def download_schedule( self, origin_id: LocationId, destination_id: LocationId, @@ -261,7 +246,7 @@ async def _download_and_save_schedule( *, date: datetime, ) -> bool: - schedule = await self.download_schedule_async( + schedule = await self.download_schedule( origin_id, destination_id, date=date, diff --git a/src/ferry_planner/server.py b/src/ferry_planner/server.py index 15e1a4c..4b3fa98 100644 --- a/src/ferry_planner/server.py +++ b/src/ferry_planner/server.py @@ -85,7 +85,7 @@ async def api_locations() -> Mapping[LocationId, Location]: responses={404: {"model": Mapping[Literal["detail"], str]}}, ) async def api_schedule(options: ScheduleOptions) -> FerrySchedule | Response: - schedule = schedule_db.get(options.origin, options.destination, date=options.date) + schedule = await schedule_db.get(options.origin, options.destination, date=options.date) if schedule is None: return Response(status_code=status.HTTP_404_NOT_FOUND, content={"detail": "Schedule not found"}) return schedule @@ -97,7 +97,7 @@ async def api_routeplans(options: RoutePlansOptions) -> Sequence[RoutePlan]: destination = location_db.by_id(options.destination) routes = route_builder.find_routes(origin=origin, destination=destination) route_plans = list( - route_plan_builder.make_route_plans( + await route_plan_builder.make_route_plans( routes=routes, options=options, schedule_getter=schedule_db.get,