Skip to content

Commit 209421c

Browse files
committed
Remove code duplication in constituent file loading (more kwargs)
1 parent c3fc7f6 commit 209421c

File tree

1 file changed

+107
-82
lines changed

1 file changed

+107
-82
lines changed

shapefile.py

Lines changed: 107 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from collections.abc import Collection
2121
from datetime import date
2222
from struct import Struct, calcsize, error, pack, unpack
23-
from typing import Any, Iterable, Iterator, Optional, Reversible, TypedDict, Union
23+
from typing import IO, Any, Iterable, Iterator, Optional, Reversible, TypedDict, Union
2424
from urllib.error import HTTPError
2525
from urllib.parse import urlparse, urlunparse
2626
from urllib.request import Request, urlopen
@@ -912,6 +912,10 @@ class ShapefileException(Exception):
912912
pass
913913

914914

915+
class _NoShpSentinel(object):
916+
pass
917+
918+
915919
class Reader:
916920
"""Reads the three files of a shapefile as a unit or
917921
separately. If one of the three files (.shp, .shx,
@@ -933,10 +937,25 @@ class Reader:
933937
but they can be.
934938
"""
935939

936-
def __init__(self, *args, encoding="utf-8", encodingErrors="strict", **kwargs):
937-
self.shp = None
938-
self.shx = None
939-
self.dbf = None
940+
CONSTITUENT_FILE_EXTS = ["shp", "shx", "dbf"]
941+
assert all(ext.islower() for ext in CONSTITUENT_FILE_EXTS)
942+
943+
def _assert_ext_is_supported(self, ext: str):
944+
assert ext in self.CONSTITUENT_FILE_EXTS
945+
946+
def __init__(
947+
self,
948+
*args,
949+
encoding="utf-8",
950+
encodingErrors="strict",
951+
shp=_NoShpSentinel,
952+
shx=None,
953+
dbf=None,
954+
**kwargs,
955+
):
956+
# self.shp = None
957+
# self.shx = None
958+
# self.dbf = None
940959
self._files_to_close = []
941960
self.shapeName = "Not specified"
942961
self._offsets = []
@@ -1014,19 +1033,20 @@ def __init__(self, *args, encoding="utf-8", encodingErrors="strict", **kwargs):
10141033
shapefile = os.path.splitext(shapefile)[
10151034
0
10161035
] # root shapefile name
1017-
for ext in ["SHP", "SHX", "DBF", "shp", "shx", "dbf"]:
1018-
try:
1019-
member = archive.open(shapefile + "." + ext)
1020-
# write zipfile member data to a read+write tempfile and use as source, gets deleted on close()
1021-
fileobj = tempfile.NamedTemporaryFile(
1022-
mode="w+b", delete=True
1023-
)
1024-
fileobj.write(member.read())
1025-
fileobj.seek(0)
1026-
setattr(self, ext.lower(), fileobj)
1027-
self._files_to_close.append(fileobj)
1028-
except:
1029-
pass
1036+
for lower_ext in self.CONSTITUENT_FILE_EXTS:
1037+
for cased_ext in [lower_ext, lower_ext.upper()]:
1038+
try:
1039+
member = archive.open(f"{shapefile}.{cased_ext}")
1040+
# write zipfile member data to a read+write tempfile and use as source, gets deleted on close()
1041+
fileobj = tempfile.NamedTemporaryFile(
1042+
mode="w+b", delete=True
1043+
)
1044+
fileobj.write(member.read())
1045+
fileobj.seek(0)
1046+
setattr(self, lower_ext, fileobj)
1047+
self._files_to_close.append(fileobj)
1048+
except:
1049+
pass
10301050
# Close and delete the temporary zipfile
10311051
try:
10321052
zipfileobj.close()
@@ -1086,46 +1106,47 @@ def __init__(self, *args, encoding="utf-8", encodingErrors="strict", **kwargs):
10861106
self.load(path)
10871107
return
10881108

1089-
# Otherwise, load from separate shp/shx/dbf args (must be path or file-like)
1090-
if "shp" in kwargs:
1091-
if hasattr(kwargs["shp"], "read"):
1092-
self.shp = kwargs["shp"]
1093-
# Copy if required
1094-
try:
1095-
self.shp.seek(0)
1096-
except (NameError, io.UnsupportedOperation):
1097-
self.shp = io.BytesIO(self.shp.read())
1098-
else:
1099-
(baseName, ext) = os.path.splitext(kwargs["shp"])
1100-
self.load_shp(baseName)
1101-
1102-
if "shx" in kwargs:
1103-
if hasattr(kwargs["shx"], "read"):
1104-
self.shx = kwargs["shx"]
1105-
# Copy if required
1106-
try:
1107-
self.shx.seek(0)
1108-
except (NameError, io.UnsupportedOperation):
1109-
self.shx = io.BytesIO(self.shx.read())
1110-
else:
1111-
(baseName, ext) = os.path.splitext(kwargs["shx"])
1112-
self.load_shx(baseName)
1109+
if shp is _NoShpSentinel:
1110+
self.shp = None
1111+
self.shx = None
1112+
else:
1113+
self.shp = self._seek_0_on_file_obj_wrap_or_open_from_name("shp", shp)
1114+
self.shx = self._seek_0_on_file_obj_wrap_or_open_from_name("shx", shx)
11131115

1114-
if "dbf" in kwargs:
1115-
if hasattr(kwargs["dbf"], "read"):
1116-
self.dbf = kwargs["dbf"]
1117-
# Copy if required
1118-
try:
1119-
self.dbf.seek(0)
1120-
except (NameError, io.UnsupportedOperation):
1121-
self.dbf = io.BytesIO(self.dbf.read())
1122-
else:
1123-
(baseName, ext) = os.path.splitext(kwargs["dbf"])
1124-
self.load_dbf(baseName)
1116+
self.dbf = self._seek_0_on_file_obj_wrap_or_open_from_name("dbf", dbf)
11251117

11261118
# Load the files
11271119
if self.shp or self.dbf:
1128-
self.load()
1120+
self._try_to_set_constituent_file_headers()
1121+
1122+
def _seek_0_on_file_obj_wrap_or_open_from_name(
1123+
self,
1124+
ext: str,
1125+
# File name, file object or anything with a read() method that returns bytes.
1126+
# TODO: Create simple Protocol with a read() method
1127+
file_: Optional[Union[str, IO[bytes]]],
1128+
) -> Union[None, io.BytesIO, IO[bytes]]:
1129+
# assert ext in {'shp', 'dbf', 'shx'}
1130+
self._assert_ext_is_supported(ext)
1131+
1132+
if file_ is None:
1133+
return None
1134+
1135+
if isinstance(file_, str):
1136+
baseName, __ = os.path.splitext(file_)
1137+
return self._load_constituent_file(baseName, ext)
1138+
1139+
if hasattr(file_, "read"):
1140+
# Copy if required
1141+
try:
1142+
file_.seek(0) # type: ignore
1143+
return file_
1144+
except (NameError, io.UnsupportedOperation):
1145+
return io.BytesIO(file_.read())
1146+
1147+
raise ShapefileException(
1148+
f"Could not load shapefile constituent file from: {file_}"
1149+
)
11291150

11301151
def __str__(self):
11311152
"""
@@ -1232,57 +1253,61 @@ def load(self, shapefile=None):
12321253
raise ShapefileException(
12331254
f"Unable to open {shapeName}.dbf or {shapeName}.shp."
12341255
)
1256+
self._try_to_set_constituent_file_headers()
1257+
1258+
def _try_to_set_constituent_file_headers(self):
12351259
if self.shp:
12361260
self.__shpHeader()
12371261
if self.dbf:
12381262
self.__dbfHeader()
12391263
if self.shx:
12401264
self.__shxHeader()
12411265

1242-
def load_shp(self, shapefile_name):
1266+
def _try_get_open_constituent_file(self, shapefile_name: str, ext: str):
12431267
"""
1244-
Attempts to load file with .shp extension as both lower and upper case
1268+
Attempts to open a .shp, .dbf or .shx file,
1269+
with both lower case and upper case file extensions,
1270+
and return it. If it was not possible to open the file, None is returned.
12451271
"""
1246-
shp_ext = "shp"
1272+
# typing.LiteralString is only available from Python 3.11 onwards.
1273+
# https://docs.python.org/3/library/typing.html#typing.LiteralString
1274+
self._assert_ext_is_supported(ext)
12471275
try:
1248-
self.shp = open(f"{shapefile_name}.{shp_ext}", "rb")
1249-
self._files_to_close.append(self.shp)
1276+
return open(f"{shapefile_name}.{ext}", "rb")
12501277
except OSError:
12511278
try:
1252-
self.shp = open(f"{shapefile_name}.{shp_ext.upper()}", "rb")
1253-
self._files_to_close.append(self.shp)
1279+
return open(f"{shapefile_name}.{ext.upper()}", "rb")
12541280
except OSError:
1255-
pass
1281+
return None
1282+
1283+
def _load_constituent_file(self, shapefile_name: str, ext: str):
1284+
"""
1285+
Attempts to open a .shp, .dbf or .shx file, with the extension
1286+
as both lower and upper case, and if successful append it to
1287+
self._files_to_close.
1288+
"""
1289+
shp_dbf_or_dhx_file = self._try_get_open_constituent_file(shapefile_name, ext)
1290+
if shp_dbf_or_dhx_file is not None:
1291+
self._files_to_close.append(shp_dbf_or_dhx_file)
1292+
return shp_dbf_or_dhx_file
1293+
1294+
def load_shp(self, shapefile_name):
1295+
"""
1296+
Attempts to load file with .shp extension as both lower and upper case
1297+
"""
1298+
self.shp = self._load_constituent_file(shapefile_name, "shp")
12561299

12571300
def load_shx(self, shapefile_name):
12581301
"""
12591302
Attempts to load file with .shx extension as both lower and upper case
12601303
"""
1261-
shx_ext = "shx"
1262-
try:
1263-
self.shx = open(f"{shapefile_name}.{shx_ext}", "rb")
1264-
self._files_to_close.append(self.shx)
1265-
except OSError:
1266-
try:
1267-
self.shx = open(f"{shapefile_name}.{shx_ext.upper()}", "rb")
1268-
self._files_to_close.append(self.shx)
1269-
except OSError:
1270-
pass
1304+
self.shx = self._load_constituent_file(shapefile_name, "shx")
12711305

12721306
def load_dbf(self, shapefile_name):
12731307
"""
12741308
Attempts to load file with .dbf extension as both lower and upper case
12751309
"""
1276-
dbf_ext = "dbf"
1277-
try:
1278-
self.dbf = open(f"{shapefile_name}.{dbf_ext}", "rb")
1279-
self._files_to_close.append(self.dbf)
1280-
except OSError:
1281-
try:
1282-
self.dbf = open(f"{shapefile_name}.{dbf_ext.upper()}", "rb")
1283-
self._files_to_close.append(self.dbf)
1284-
except OSError:
1285-
pass
1310+
self.dbf = self._load_constituent_file(shapefile_name, "dbf")
12861311

12871312
def __del__(self):
12881313
self.close()

0 commit comments

Comments
 (0)