From a2ff313d9155f109e63863ab0a1d0e8d642bc388 Mon Sep 17 00:00:00 2001 From: George Shammas Date: Sat, 3 Dec 2022 02:04:22 -0500 Subject: [PATCH] Fix typing errors --- mtapi/mtapi.py | 8 +++-- mtaproto/feedresponse.py | 70 +++++++++++++++++++--------------------- 2 files changed, 40 insertions(+), 38 deletions(-) diff --git a/mtapi/mtapi.py b/mtapi/mtapi.py index 3599ec1..5eafad2 100644 --- a/mtapi/mtapi.py +++ b/mtapi/mtapi.py @@ -1,4 +1,5 @@ -import urllib, contextlib, datetime, copy +import urllib.request, urllib.error +import contextlib, datetime, copy from collections import defaultdict from itertools import islice from operator import itemgetter @@ -8,10 +9,13 @@ import google.protobuf.message from mtaproto.feedresponse import FeedResponse, Trip, TripStop, TZ from mtapi._mtapithreader import _MtapiThreader +from typing import TypeAlias logger = logging.getLogger(__name__) -def distance(p1, p2): +point: TypeAlias = tuple[float, float] | list[float] + +def distance(p1: point, p2: point) -> float: return math.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2) class Mtapi(object): diff --git a/mtaproto/feedresponse.py b/mtaproto/feedresponse.py index 3316b5f..f36a18d 100644 --- a/mtaproto/feedresponse.py +++ b/mtaproto/feedresponse.py @@ -1,59 +1,57 @@ -from mtaproto import nyct_subway_pb2 -from pytz import timezone import datetime +from pytz import timezone + +from . import nyct_subway_pb2 +from . import gtfs_realtime_pb2 + TZ = timezone('US/Eastern') class FeedResponse(object): def __init__(self, response_string): - self._pb_data = nyct_subway_pb2.gtfs__realtime__pb2.FeedMessage() + gtfs_realtime_pb2.FeedMessage() + self._pb_data = gtfs_realtime_pb2.FeedMessage() self._pb_data.ParseFromString(response_string) - def __getattr__(self, name): - - if name == 'timestamp': - return datetime.datetime.fromtimestamp(self._pb_data.header.timestamp, TZ) - - - return getattr(self._pb_data, name) - + @property + def timestamp(self): + return datetime.datetime.fromtimestamp(self._pb_data.header.timestamp, TZ) + + @property + def entity(self): + return self._pb_data.entity + class Trip(object): - def __init__(self, pb_data): + def __init__(self, pb_data: gtfs_realtime_pb2.FeedEntity): self._pb_data = pb_data - def __getattr__(self, name): - - if name == 'direction': - return self._direction() - elif name == 'route_id': - if self._pb_data.trip_update.trip.route_id == 'GS': - return 'S' - else: - return self._pb_data.trip_update.trip.route_id - - - return getattr(self._pb_data, name) - - def _direction(self): + @property + def direction(self): trip_meta = self._pb_data.trip_update.trip.Extensions[nyct_subway_pb2.nyct_trip_descriptor] return nyct_subway_pb2.NyctTripDescriptor.Direction.Name(trip_meta.direction) + + @property + def route_id(self): + if self._pb_data.trip_update.trip.route_id == 'GS': + return 'S' + else: + return self._pb_data.trip_update.trip.route_id def is_valid(self): return bool(self._pb_data.trip_update) class TripStop(object): - def __init__(self, pb_data): + def __init__(self, pb_data: gtfs_realtime_pb2.TripUpdate.StopTimeUpdate): self._pb_data = pb_data - def __getattr__(self, name): - - if name == 'time': - raw_time = self._pb_data.arrival.time or self._pb_data.departure.time - return datetime.datetime.fromtimestamp(raw_time, TZ) - elif name == 'stop_id': - return str(self._pb_data.stop_id[:3]) - - return getattr(self._pb_data, name) + @property + def time(self): + raw_time = self._pb_data.arrival.time or self._pb_data.departure.time + return datetime.datetime.fromtimestamp(raw_time, TZ) + + @property + def stop_id(self): + return str(self._pb_data.stop_id[:3])