Skip to content

Commit

Permalink
fix nocolor and add stations list
Browse files Browse the repository at this point in the history
bump version
  • Loading branch information
lzgirlcat committed Sep 29, 2024
1 parent 6233d64 commit 6332353
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 14 deletions.
2 changes: 1 addition & 1 deletion koleo/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_stations(self) -> list[ExtendedStationInfo]:

def find_station(self, query: str, language: str = "pl") -> list[ExtendedStationInfo]:
# https://koleo.pl/ls?q=tere&language=pl
return self._get_json("/ls", query={"q": query, "language": language})
return self._get_json("/ls", params={"q": query, "language": language})["stations"]

def get_station_by_slug(self, slug: str) -> ExtendedBaseStationInfo:
# https://koleo.pl/api/v2/main/stations/by_slug/inowroclaw
Expand Down
55 changes: 43 additions & 12 deletions koleo/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from argparse import ArgumentParser
from datetime import datetime, timedelta

Expand All @@ -22,7 +23,16 @@ def __init__(
) -> None:
self._client = client
self._storage = storage
self.console = Console(color_system="standard", no_color=no_color)
self.no_color = no_color
self.console = Console(color_system="standard", no_color=no_color, highlight=False)

def print(self, text, *args, **kwargs):
if self.no_color:
result = re.sub(r'\[[^\]]*\]', '', text)
print(result)
else:
self.console.print(text, *args, **kwargs)


@property
def client(self) -> KoleoAPI:
Expand Down Expand Up @@ -59,7 +69,7 @@ def get_departures(self, station_id: int, date: datetime):
if datetime.fromisoformat(i["departure"]).timestamp() > date.timestamp() # type: ignore
]
table = self.trains_on_station_table(trains)
self.console.print(table)
self.print(table)
return table

def get_arrivals(self, station_id: int, date: datetime):
Expand All @@ -74,21 +84,32 @@ def get_arrivals(self, station_id: int, date: datetime):
if datetime.fromisoformat(i["arrival"]).timestamp() > date.timestamp() # type: ignore
]
table = self.trains_on_station_table(trains, type=2)
self.console.print(table)
self.print(table)
return table

def full_departures(self, station: str, date: datetime):
st = self.get_station(station)
station_info = f"[bold blue]{st["name"]}[/bold blue] ID: {st["id"]}"
self.console.print(station_info)
self.print(station_info)
self.get_departures(st["id"], date)

def full_arrivals(self, station: str, date: datetime):
st = self.get_station(station)
station_info = f"[bold blue]{st["name"]}[/bold blue] ID: {st["id"]}"
self.console.print(station_info)
self.print(station_info)
self.get_arrivals(st["id"], date)

def find_station(self, query: str | None):
if query:
stations = self.client.find_station(query)
else:
stations = (
self.storage.get_cache("stations") or
self.storage.set_cache("stations", self.client.get_stations())
)
for st in stations:
self.print(f"[bold blue]{st["name"]}[/bold blue] ID: {st["id"]}")

def train_info(self, brand: str, name: str, date: datetime):
brand = brand.upper().strip()
name = name.strip()
Expand All @@ -113,7 +134,7 @@ def train_info(self, brand: str, name: str, date: datetime):
train_id = train_calendars["train_calendars"][0]["date_train_map"][date.strftime("%Y-%m-%d")]
train_details = self.client.get_train(train_id)
brand = next(iter(i for i in brands if i["id"] == train_details["train"]["brand_id"]), {}).get("logo_text", "")
parts = [f"{brand} {train_details["train"]["train_full_name"]}"]
parts = [f"[red]{brand}[/red] [bold blue]{train_details["train"]["train_full_name"]}[/bold blue]"]
route_start = arr_dep_to_dt(train_details["stops"][0]["departure"])
route_end = arr_dep_to_dt(train_details["stops"][-1]["arrival"])
if route_end.hour < route_start.hour or (route_end.hour==route_start.hour and route_end.minute < route_end.minute):
Expand All @@ -134,8 +155,8 @@ def train_info(self, brand: str, name: str, date: datetime):
parts.append(f"[bold green] {start} - {keys[i]}:[/bold green] {vehicle_types[start]}")
start = keys[i]
parts.append(f"[bold green] {start} - {keys[-1]}:[/bold green] {vehicle_types[start]}")
self.console.print("\n".join(parts))
self.console.print(self.train_route_table(train_details))
self.print("\n".join(parts))
self.print(self.train_route_table(train_details))

def connections(self, start: str, end: str, date: datetime, brands: list[str], direct: bool = False, purchasable: bool = False):
start_station = self.get_station(start)
Expand All @@ -147,7 +168,7 @@ def connections(self, start: str, end: str, date: datetime, brands: list[str], d
else:
connection_brands = [i["id"] for i in api_brands if i["name"].lower().strip() in brands or i["logo_text"].lower().strip() in brands]
if not connection_brands:
self.console.print(f'[bold red]No brands match: "{', '.join(brands)}"[/bold red]')
self.print(f'[bold red]No brands match: "{', '.join(brands)}"[/bold red]')
exit(2)
connections = self.client.get_connections(
start_station["name_slug"],
Expand All @@ -168,7 +189,7 @@ def trains_on_station_table(self, trains: list[TrainOnStationInfo], type: int =
platform = convert_platform_number(train["platform"]) if train["platform"] else ""
position_info = f"{platform}/{train["track"]}" if train["track"] else platform
parts.append(
f"[bold green]{time[11:16]}[/bold green] {brand} {train["train_full_name"]}[purple] {train["stations"][0]["name"]} {position_info}[/purple]"
f"[bold green]{time[11:16]}[/bold green] [red]{brand}[/red] {train["train_full_name"]}[purple] {train["stations"][0]["name"]} {position_info}[/purple]"
)
return "\n".join(parts)

Expand All @@ -190,7 +211,7 @@ def get_station(self, station: str) -> ExtendedBaseStationInfo:
f"st-{slug}", self.client.get_station_by_slug(slug)
)
except self.client.errors.KoleoNotFound:
self.console.print(f'[bold red]Station not found: "{station}"[/bold red]')
self.print(f'[bold red]Station not found: "{station}"[/bold red]')
exit(2)

def main():
Expand Down Expand Up @@ -251,9 +272,18 @@ def main():
)
train_route.set_defaults(func=cli.train_info, pass_=["brand", "name", "date"])

stations = subparsers.add_parser("stations", aliases=["s", "find", "f", "stacje"], help="Allows you to find stations by their name")
stations.add_argument(
"query",
help="The station name",
default=None,
nargs="?",
)
stations.set_defaults(func=cli.find_station, pass_=["query"])

connections = subparsers.add_parser(
"connections",
aliases=["do", "z", "szukaj", "path", "find"],
aliases=["do", "z", "szukaj", "path"],
help="Allows you to search for connections from a to b",
)
connections.add_argument("start", help="The starting station", type=str)
Expand Down Expand Up @@ -291,6 +321,7 @@ def main():
client = KoleoAPI()
cli.client, cli.storage = client, storage
cli.console.no_color = args.nocolor
cli.no_color = args.nocolor
if hasattr(args, "station") and args.station is None:
args.station = storage.favourite_station
elif hasattr(args, "station") and args.save:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def parse_requirements_file(path):

setuptools.setup(
name="koleo-cli",
version="0.2.137.2",
version="0.2.137.3",
description="Koleo CLI",
long_description=long_description(),
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 6332353

Please sign in to comment.