Skip to content

Commit 0bde698

Browse files
committed
Update shapefile.py
1 parent 701e9b0 commit 0bde698

File tree

1 file changed

+58
-20
lines changed

1 file changed

+58
-20
lines changed

shapefile.py

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@
103103
# File name, file object or anything with a read() method that returns bytes.
104104
# TODO: Create simple Protocol with a read() method
105105
BinaryFileT = Union[str, IO[bytes]]
106+
BinaryFileStreamT = Union[IO[bytes], io.BytesIO]
107+
108+
RecordValue = Union[float, str, date]
106109

107110

108111
class GeoJsonShapeT(TypedDict):
@@ -717,7 +720,12 @@ class _Record(list):
717720
>>> print(r.ID)
718721
"""
719722

720-
def __init__(self, field_positions, values, oid=None):
723+
def __init__(
724+
self,
725+
field_positions: dict[str, int],
726+
values: Iterable[RecordValue],
727+
oid: Optional[int] = None,
728+
):
721729
"""
722730
A Record should be created by the Reader class
723731
@@ -732,7 +740,7 @@ def __init__(self, field_positions, values, oid=None):
732740
self.__oid = -1
733741
list.__init__(self, values)
734742

735-
def __getattr__(self, item):
743+
def __getattr__(self, item: str) -> RecordValue:
736744
"""
737745
__getattr__ is called if an attribute is used that does
738746
not exist in the normal sense. For example r=Record(...), r.ID
@@ -755,7 +763,7 @@ def __getattr__(self, item):
755763
f"{item} found as a field but not enough values available."
756764
)
757765

758-
def __setattr__(self, key, value):
766+
def __setattr__(self, key: str, value: RecordValue):
759767
"""
760768
Sets a value of a field attribute
761769
:param key: The field name
@@ -811,11 +819,11 @@ def __setitem__(self, key, value):
811819
raise IndexError(f"{key} is not a field name and not an int")
812820

813821
@property
814-
def oid(self):
822+
def oid(self) -> int:
815823
"""The index position of the record in the original shapefile"""
816824
return self.__oid
817825

818-
def as_dict(self, date_strings=False):
826+
def as_dict(self, date_strings: bool = False) -> dict[str, RecordValue]:
819827
"""
820828
Returns this Record as a dictionary using the field names as keys
821829
:return: dict
@@ -830,7 +838,7 @@ def as_dict(self, date_strings=False):
830838
def __repr__(self):
831839
return f"Record #{self.__oid}: {list(self)}"
832840

833-
def __dir__(self):
841+
def __dir__(self) -> list[str]:
834842
"""
835843
Helps to show the field names in an interactive environment like IPython.
836844
See: http://ipython.readthedocs.io/en/stable/config/integrating.html
@@ -856,7 +864,7 @@ class ShapeRecord:
856864
"""A ShapeRecord object containing a shape along with its attributes.
857865
Provides the GeoJSON __geo_interface__ to return a Feature dictionary."""
858866

859-
def __init__(self, shape=None, record=None):
867+
def __init__(self, shape: Optional[Shape] = None, record: Optional[_Record] = None):
860868
self.shape = shape
861869
self.record = record
862870

@@ -967,12 +975,12 @@ def __init__(
967975
self.shp = None
968976
self.shx = None
969977
self.dbf = None
970-
self._files_to_close: list[IO[bytes]] = []
978+
self._files_to_close: list[BinaryFileStreamT] = []
971979
self.shapeName = "Not specified"
972980
self._offsets: list[int] = []
973-
self.shpLength = None
974-
self.numRecords = None
975-
self.numShapes = None
981+
self.shpLength: Optional[int] = None
982+
self.numRecords: Optional[int] = None
983+
self.numShapes: Optional[int] = None
976984
self.fields: list[list[str]] = []
977985
self.__dbfHdrLength = 0
978986
self.__fieldLookup: dict[str, int] = {}
@@ -1131,7 +1139,7 @@ def __seek_0_on_file_obj_wrap_or_open_from_name(
11311139
self,
11321140
ext: str,
11331141
file_: Optional[BinaryFileT],
1134-
) -> Union[None, io.BytesIO, IO[bytes]]:
1142+
) -> Union[None, IO[bytes]]:
11351143
# assert ext in {'shp', 'dbf', 'shx'}
11361144
self._assert_ext_is_supported(ext)
11371145

@@ -1615,6 +1623,7 @@ def __dbfHeader(self):
16151623
self.numRecords, self.__dbfHdrLength, self.__recordLength = unpack(
16161624
"<xxxxLHH20x", dbf.read(32)
16171625
)
1626+
16181627
# read fields
16191628
numFields = (self.__dbfHdrLength - 33) // 32
16201629
for field in range(numFields):
@@ -1709,7 +1718,13 @@ def __recordFields(self, fields=None):
17091718
recLookup = self.__fullRecLookup
17101719
return fieldTuples, recLookup, recStruct
17111720

1712-
def __record(self, fieldTuples, recLookup, recStruct, oid=None):
1721+
def __record(
1722+
self,
1723+
fieldTuples: list[tuple[str, str, int, bool]],
1724+
recLookup: dict[str, int],
1725+
recStruct: Struct,
1726+
oid: Optional[int] = None,
1727+
) -> Optional[_Record]:
17131728
"""Reads and returns a dbf record row as a list of values. Requires specifying
17141729
a list of field info tuples 'fieldTuples', a record name-index dict 'recLookup',
17151730
and a Struct instance 'recStruct' for unpacking these fields.
@@ -1801,7 +1816,9 @@ def __record(self, fieldTuples, recLookup, recStruct, oid=None):
18011816

18021817
return _Record(recLookup, record, oid)
18031818

1804-
def record(self, i=0, fields=None):
1819+
def record(
1820+
self, i: int = 0, fields: Optional[list[str]] = None
1821+
) -> Optional[_Record]:
18051822
"""Returns a specific dbf record based on the supplied index.
18061823
To only read some of the fields, specify the 'fields' arg as a
18071824
list of one or more fieldnames.
@@ -1818,7 +1835,7 @@ def record(self, i=0, fields=None):
18181835
oid=i, fieldTuples=fieldTuples, recLookup=recLookup, recStruct=recStruct
18191836
)
18201837

1821-
def records(self, fields=None):
1838+
def records(self, fields: Optional[list[str]] = None) -> list[_Record]:
18221839
"""Returns all records in a dbf file.
18231840
To only read some of the fields, specify the 'fields' arg as a
18241841
list of one or more fieldnames.
@@ -1829,15 +1846,20 @@ def records(self, fields=None):
18291846
f = self.__getFileObj(self.dbf)
18301847
f.seek(self.__dbfHdrLength)
18311848
fieldTuples, recLookup, recStruct = self.__recordFields(fields)
1832-
for i in range(self.numRecords):
1849+
for i in range(self.numRecords): # type: ignore
18331850
r = self.__record(
18341851
oid=i, fieldTuples=fieldTuples, recLookup=recLookup, recStruct=recStruct
18351852
)
18361853
if r:
18371854
records.append(r)
18381855
return records
18391856

1840-
def iterRecords(self, fields=None, start=0, stop=None):
1857+
def iterRecords(
1858+
self,
1859+
fields=Optional[list[str]],
1860+
start: int = 0,
1861+
stop: Optional[int] = None,
1862+
) -> Iterator[Optional[_Record]]:
18411863
"""Returns a generator of records in a dbf file.
18421864
Useful for large shapefiles or dbf files.
18431865
To only read some of the fields, specify the 'fields' arg as a
@@ -1851,6 +1873,8 @@ def iterRecords(self, fields=None, start=0, stop=None):
18511873
"""
18521874
if self.numRecords is None:
18531875
self.__dbfHeader()
1876+
if not isinstance(self.numRecords, int):
1877+
raise Exception("Error when reading number of Records in dbf file header")
18541878
f = self.__getFileObj(self.dbf)
18551879
start = self.__restrictIndex(start)
18561880
if stop is None:
@@ -1871,7 +1895,12 @@ def iterRecords(self, fields=None, start=0, stop=None):
18711895
if r:
18721896
yield r
18731897

1874-
def shapeRecord(self, i=0, fields=None, bbox=None):
1898+
def shapeRecord(
1899+
self,
1900+
i: int = 0,
1901+
fields: Optional[list[str]] = None,
1902+
bbox: Optional[BBox] = None,
1903+
) -> Optional[ShapeRecord]:
18751904
"""Returns a combination geometry and attribute record for the
18761905
supplied record index.
18771906
To only read some of the fields, specify the 'fields' arg as a
@@ -1884,8 +1913,13 @@ def shapeRecord(self, i=0, fields=None, bbox=None):
18841913
if shape:
18851914
record = self.record(i, fields=fields)
18861915
return ShapeRecord(shape=shape, record=record)
1916+
return None
18871917

1888-
def shapeRecords(self, fields=None, bbox=None):
1918+
def shapeRecords(
1919+
self,
1920+
fields: Optional[list[str]] = None,
1921+
bbox: Optional[BBox] = None,
1922+
) -> ShapeRecords:
18891923
"""Returns a list of combination geometry/attribute records for
18901924
all records in a shapefile.
18911925
To only read some of the fields, specify the 'fields' arg as a
@@ -1895,7 +1929,11 @@ def shapeRecords(self, fields=None, bbox=None):
18951929
"""
18961930
return ShapeRecords(self.iterShapeRecords(fields=fields, bbox=bbox))
18971931

1898-
def iterShapeRecords(self, fields=None, bbox=None):
1932+
def iterShapeRecords(
1933+
self,
1934+
fields: Optional[list[str]] = None,
1935+
bbox: Optional[BBox] = None,
1936+
) -> Iterator[ShapeRecord]:
18991937
"""Returns a generator of combination geometry/attribute records for
19001938
all records in a shapefile.
19011939
To only read some of the fields, specify the 'fields' arg as a

0 commit comments

Comments
 (0)