Skip to content

Commit

Permalink
Merge pull request #81 from lemonyte/async-routeplan-builder
Browse files Browse the repository at this point in the history
  • Loading branch information
lemonyte authored Jun 7, 2024
2 parents 80a6a73 + fb16f14 commit 509b103
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 34 deletions.
32 changes: 20 additions & 12 deletions src/ferry_planner/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,46 +253,51 @@ class RoutePlanBuilder:
def __init__(self, connection_db: ConnectionDB, /) -> None:
self._connection_db = connection_db

def make_route_plans(
async def make_route_plans(
self,
*,
routes: Iterable[Route],
options: RoutePlansOptions,
schedule_getter: ScheduleGetter,
) -> Iterator[RoutePlan]:
) -> Sequence[RoutePlan]:
route_plans = []
for route in routes:
yield from self._add_plan_segment(
await self._add_plan_segment(
route_plans=route_plans,
route=route,
destination_index=1,
start_time=options.date.replace(hour=0, minute=0, second=0, microsecond=0),
options=options,
schedule_getter=schedule_getter,
)
return route_plans

def _add_plan_segment( # noqa: PLR0913
async def _add_plan_segment( # noqa: PLR0913
self,
*,
route_plans: list,
route: Route,
destination_index: int,
start_time: datetime,
options: RoutePlansOptions,
schedule_getter: ScheduleGetter,
segments: list[RoutePlanSegment] | None = None,
) -> Generator[RoutePlan, None, bool]:
) -> bool:
if segments is None:
segments = []
result = False
try:
if destination_index == len(route):
if not segments: # empty list?
return False # can we be here?
yield RoutePlan.from_segments(segments)
route_plans.append(RoutePlan.from_segments(segments))
return True
origin = route[destination_index - 1]
destination = route[destination_index]
connection = self._connection_db.from_to_location(origin, destination)
if isinstance(connection, FerryConnection):
result = yield from self._add_ferry_connection(
result = await self._add_ferry_connection(
route_plans=route_plans,
route=route,
destination_index=destination_index,
segments=segments,
Expand All @@ -318,7 +323,8 @@ def _add_plan_segment( # noqa: PLR0913
),
)
segments.append(RoutePlanSegment(connection=connection, times=times))
result = yield from self._add_plan_segment(
result = await self._add_plan_segment(
route_plans=route_plans,
route=route,
destination_index=destination_index + 1,
segments=segments,
Expand All @@ -331,21 +337,22 @@ def _add_plan_segment( # noqa: PLR0913
del segments[delete_start:]
return result

def _add_ferry_connection( # noqa: C901, PLR0912, PLR0913
async def _add_ferry_connection( # noqa: C901, PLR0912, PLR0913
self,
*,
route_plans: list,
route: Route,
destination_index: int,
segments: list[RoutePlanSegment],
start_time: datetime,
options: RoutePlansOptions,
connection: FerryConnection,
schedule_getter: ScheduleGetter,
) -> Generator[RoutePlan, None, bool]:
) -> bool:
result = False
depature_terminal = connection.origin
day = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
schedule = schedule_getter(connection.origin.id, connection.destination.id, date=day)
schedule = await schedule_getter(connection.origin.id, connection.destination.id, date=day)
if not schedule:
return False
for sailing in schedule.sailings:
Expand Down Expand Up @@ -406,7 +413,8 @@ def _add_ferry_connection( # noqa: C901, PLR0912, PLR0913
schedule_url=schedule.url,
),
)
recursion_result = yield from self._add_plan_segment(
recursion_result = await self._add_plan_segment(
route_plans=route_plans,
route=route,
destination_index=destination_index + 1,
segments=segments,
Expand Down
25 changes: 5 additions & 20 deletions src/ferry_planner/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class FerrySchedule(BaseModel):


class ScheduleGetter(Protocol):
def __call__(
async def __call__(
self,
origin_id: LocationId,
destination_id: LocationId,
Expand Down Expand Up @@ -149,7 +149,7 @@ def _get_filepath(
) -> Path:
return self.cache_dir / f"{origin_id}-{destination_id}" / f"{date.date()}.json"

def get(
async def get(
self,
origin_id: LocationId,
destination_id: LocationId,
Expand All @@ -165,7 +165,7 @@ def get(
schedule = FerrySchedule.model_validate_json(filepath.read_text(encoding="utf-8"))
self._mem_cache[filepath] = schedule
return schedule
schedule = self.download_schedule(origin_id, destination_id, date=date)
schedule = await self.download_schedule(origin_id, destination_id, date=date)
if schedule:
self.put(schedule)
return schedule
Expand All @@ -182,22 +182,7 @@ def put(self, schedule: FerrySchedule, /) -> None:
dirpath.mkdir(mode=0o755, parents=True, exist_ok=True)
filepath.write_text(schedule.model_dump_json(indent=4, exclude_none=True), encoding="utf-8")

def download_schedule(
self,
origin_id: LocationId,
destination_id: LocationId,
/,
*,
date: datetime,
) -> FerrySchedule | None:
coro = self.download_schedule_async(origin_id, destination_id, date=date)
try:
loop = asyncio.get_running_loop()
return loop.run_until_complete(coro)
except RuntimeError:
return asyncio.run(coro)

async def download_schedule_async(
async def download_schedule(
self,
origin_id: LocationId,
destination_id: LocationId,
Expand Down Expand Up @@ -261,7 +246,7 @@ async def _download_and_save_schedule(
*,
date: datetime,
) -> bool:
schedule = await self.download_schedule_async(
schedule = await self.download_schedule(
origin_id,
destination_id,
date=date,
Expand Down
4 changes: 2 additions & 2 deletions src/ferry_planner/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def api_locations() -> Mapping[LocationId, Location]:
responses={404: {"model": Mapping[Literal["detail"], str]}},
)
async def api_schedule(options: ScheduleOptions) -> FerrySchedule | Response:
schedule = schedule_db.get(options.origin, options.destination, date=options.date)
schedule = await schedule_db.get(options.origin, options.destination, date=options.date)
if schedule is None:
return Response(status_code=status.HTTP_404_NOT_FOUND, content={"detail": "Schedule not found"})
return schedule
Expand All @@ -97,7 +97,7 @@ async def api_routeplans(options: RoutePlansOptions) -> Sequence[RoutePlan]:
destination = location_db.by_id(options.destination)
routes = route_builder.find_routes(origin=origin, destination=destination)
route_plans = list(
route_plan_builder.make_route_plans(
await route_plan_builder.make_route_plans(
routes=routes,
options=options,
schedule_getter=schedule_db.get,
Expand Down

0 comments on commit 509b103

Please sign in to comment.