diff --git a/pyproject.toml b/pyproject.toml
index 0760ba61..ccb5a736 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "ssb-sgis"
-version = "1.1.0"
+version = "1.1.1"
 description = "GIS functions used at Statistics Norway."
 authors = ["Morten Letnes "]
 license = "MIT"
diff --git a/src/sgis/__init__.py b/src/sgis/__init__.py
index e6cd6ad0..fb32115f 100644
--- a/src/sgis/__init__.py
+++ b/src/sgis/__init__.py
@@ -1,10 +1,6 @@
-config = {
-    "n_jobs": 1,
-}
-
-
 import sgis.raster.indices as indices
+from .conf import config
 from .geopandas_tools.bounds import Gridlooper
 from .geopandas_tools.bounds import bounds_to_points
 from .geopandas_tools.bounds import bounds_to_polygon
diff --git a/src/sgis/conf.py b/src/sgis/conf.py
new file mode 100644
index 00000000..d6223ad8
--- /dev/null
+++ b/src/sgis/conf.py
@@ -0,0 +1,16 @@
+try:
+    from gcsfs import GCSFileSystem
+except ImportError:
+
+    class GCSFileSystem:
+        """Placeholder."""
+
+        def __init__(self, *args, **kwargs) -> None:
+            """Placeholder."""
+            raise ImportError("gcsfs")
+
+
+config = {
+    "n_jobs": 1,
+    "file_system": GCSFileSystem,
+}
diff --git a/src/sgis/io/dapla_functions.py b/src/sgis/io/dapla_functions.py
index 80672904..d5261a7d 100644
--- a/src/sgis/io/dapla_functions.py
+++ b/src/sgis/io/dapla_functions.py
@@ -2,13 +2,17 @@
 
 from __future__ import annotations
 
+import functools
+import glob
 import json
 import multiprocessing
 import os
+import shutil
+import uuid
 from collections.abc import Iterable
+from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 
-import dapla as dp
 import geopandas as gpd
 import joblib
 import pandas as pd
@@ -22,10 +26,12 @@
 from pandas import DataFrame
 from pyarrow import ArrowInvalid
 
+from ..geopandas_tools.conversion import to_shapely
 from ..geopandas_tools.general import get_common_crs
 from ..geopandas_tools.sfilter import sfilter
 
 PANDAS_FALLBACK_INFO = " Set pandas_fallback=True to ignore this error."
+from ..conf import config
 
 
 def read_geopandas(
@@ -63,7 +69,7 @@ def read_geopandas(
         A GeoDataFrame if it has rows. If zero rows, a pandas DataFrame is returned.
""" if file_system is None: - file_system = dp.FileClient.get_gcs_file_system() + file_system = config["file_system"]() if not isinstance(gcs_path, (str | Path | os.PathLike)): kwargs |= {"file_system": file_system, "pandas_fallback": pandas_fallback} @@ -130,6 +136,18 @@ def read_geopandas( except TypeError as e: raise TypeError(f"Unexpected type {type(gcs_path)}.") from e + if has_partitions(gcs_path, file_system): + filters = kwargs.pop("filters", None) + return _read_partitioned_parquet( + gcs_path, + file_system=file_system, + mask=mask, + pandas_fallback=pandas_fallback, + threads=threads, + filters=filters, + **kwargs, + ) + if "parquet" in gcs_path or "prqt" in gcs_path: with file_system.open(gcs_path, mode="rb") as file: try: @@ -179,31 +197,42 @@ def read_geopandas( def _get_bounds_parquet( path: str | Path, file_system: GCSFileSystem, pandas_fallback: bool = False ) -> tuple[list[float], dict] | tuple[None, None]: - with file_system.open(path) as f: + with file_system.open(path, "rb") as file: + return _get_bounds_parquet_from_open_file(file, file_system) + + +def _get_bounds_parquet_from_open_file( + file, file_system +) -> tuple[list[float], dict] | tuple[None, None]: + geo_metadata = _get_geo_metadata(file, file_system) + if not geo_metadata: + return None, None + return geo_metadata["bbox"], geo_metadata["crs"] + + +def _get_geo_metadata(file, file_system) -> dict: + meta = pq.read_schema(file).metadata + geo_metadata = json.loads(meta[b"geo"]) + try: + primary_column = geo_metadata["primary_column"] + except KeyError as e: + raise KeyError(e, geo_metadata) from e + try: + return geo_metadata["columns"][primary_column] + except KeyError as e: try: - num_rows = pq.read_metadata(f).num_rows + num_rows = pq.read_metadata(file).num_rows except ArrowInvalid as e: - if not file_system.isfile(f): - return None, None - raise ArrowInvalid(e, path) from e + if not file_system.isfile(file): + return {} + raise ArrowInvalid(e, file) from e if not num_rows: - return None, None - meta = pq.read_schema(f).metadata - try: - meta = json.loads(meta[b"geo"])["columns"]["geometry"] - except KeyError as e: - if pandas_fallback: - return None, None - raise KeyError( - f"{e.__class__.__name__}: {e} for {path}." + PANDAS_FALLBACK_INFO, - # f"{num_rows=}", - # meta, - ) from e - return meta["bbox"], meta["crs"] + return {} + return {} def _get_columns(path: str | Path, file_system: GCSFileSystem) -> pd.Index: - with file_system.open(path) as f: + with file_system.open(path, "rb") as f: schema = pq.read_schema(f) index_cols = _get_index_cols(schema) return pd.Index(schema.names).difference(index_cols) @@ -242,8 +271,7 @@ def get_bounds_series( --------- >>> import sgis as sg >>> import dapla as dp - >>> file_system = dp.FileClient.get_gcs_file_system() - >>> all_paths = file_system.ls("...") + >>> all_paths = GCSFileSystem().ls("...") Get the bounds of all your file paths, indexed by path. @@ -275,7 +303,7 @@ def get_bounds_series( """ if file_system is None: - file_system = dp.FileClient.get_gcs_file_system() + file_system = config["file_system"]() if threads is None: threads = min(len(paths), int(multiprocessing.cpu_count())) or 1 @@ -308,7 +336,7 @@ def write_geopandas( overwrite: bool = True, pandas_fallback: bool = False, file_system: GCSFileSystem | None = None, - write_covering_bbox: bool = False, + partition_cols=None, **kwargs, ) -> None: """Writes a GeoDataFrame to the speficied format. 
@@ -324,13 +352,7 @@ def write_geopandas(
             not be written with geopandas and the number of rows is more than 0.
             If True, the file will be written without geo-metadata if >0 rows.
         file_system: Optional file sustem.
-        write_covering_bbox: Writes the bounding box column for each row entry with column name "bbox".
-            Writing a bbox column can be computationally expensive, but allows you to specify
-            a bbox in : func:read_parquet for filtered reading.
-            Note: this bbox column is part of the newer GeoParquet 1.1 specification and should be
-            considered as experimental. While writing the column is backwards compatible, using it
-            for filtering may not be supported by all readers.
-
+        partition_cols: Column(s) to partition by. Only for parquet files.
        **kwargs: Additional keyword arguments passed to parquet.write_table
            (for parquet) or geopandas' to_file method (if not parquet).
    """
@@ -340,22 +362,25 @@
     except TypeError as e:
         raise TypeError(f"Unexpected type {type(gcs_path)}.") from e
 
-    if not overwrite and exists(gcs_path):
+    if file_system is None:
+        file_system = config["file_system"]()
+
+    if not overwrite and file_system.exists(gcs_path):
         raise ValueError("File already exists.")
 
     if not isinstance(df, GeoDataFrame):
         raise ValueError("DataFrame must be GeoDataFrame.")
 
-    if file_system is None:
-        file_system = dp.FileClient.get_gcs_file_system()
-
-    if not len(df):
+    if not len(df) and has_partitions(gcs_path, file_system):
+        return
+    elif not len(df):
         if pandas_fallback:
             df = pd.DataFrame(df)
             df.geometry = df.geometry.astype(str)
             df.geometry = None
         try:
-            dp.write_pandas(df, gcs_path, **kwargs)
+            with file_system.open(gcs_path, "wb") as file:
+                df.to_parquet(file, **kwargs)
         except Exception as e:
             more_txt = PANDAS_FALLBACK_INFO if not pandas_fallback else ""
             raise e.__class__(
@@ -363,17 +388,22 @@
             ) from e
         return
 
-    file_system = dp.FileClient.get_gcs_file_system()
-
     if ".parquet" in gcs_path or "prqt" in gcs_path:
-        with file_system.open(gcs_path, mode="wb") as buffer:
+        if partition_cols is not None:
+            return _write_partitioned_geoparquet(
+                df,
+                gcs_path,
+                partition_cols,
+                file_system,
+                **kwargs,
+            )
+        with file_system.open(gcs_path, mode="wb") as file:
             table = _geopandas_to_arrow(
                 df,
                 index=df.index,
                 schema_version=None,
-                write_covering_bbox=write_covering_bbox,
             )
-            pq.write_table(table, buffer, compression="snappy", **kwargs)
+            pq.write_table(table, file, compression="snappy", **kwargs)
         return
 
     layer = kwargs.pop("layer", None)
@@ -393,17 +423,156 @@
         df.to_file(file, driver=driver, layer=layer)
 
 
-def exists(path: str | Path) -> bool:
-    """Returns True if the path exists, and False if it doesn't.
+def _remove_file(path, file_system) -> None:
+    try:
+        file_system.rm_file(path)
+    except (AttributeError, TypeError, PermissionError):
+        try:
+            shutil.rmtree(path)
+        except NotADirectoryError:
+            try:
+                os.remove(path)
+            except PermissionError:
+                pass
 
-    Args:
-        path (str): The path to the file or directory.
 
-    Returns:
-        True if the path exists, False if not.
- """ - file_system = dp.FileClient.get_gcs_file_system() - return file_system.exists(path) +def _write_partitioned_geoparquet(df, path, partition_cols, file_system, **kwargs): + path = Path(path) + unique_id = uuid.uuid4() + + try: + glob_func = functools.partial(file_system.glob, detail=False) + except AttributeError: + glob_func = functools.partial(glob.glob, recursive=True) + + args: list[tuple[Path, DataFrame]] = [] + dirs: list[Path] = set() + for group, rows in df.groupby(partition_cols): + name = ( + "/".join( + f"{col}={value}" + for col, value in zip(partition_cols, group, strict=True) + ) + + f"/{unique_id}.parquet" + ) + + dirs.add((path / name).parent) + args.append((path / name, rows)) + + if file_system.exists(path) and not has_partitions(path, file_system): + _remove_file(path, file_system) + + for dir_ in dirs: + try: + os.makedirs(dir_, exist_ok=True) + except (OSError, FileNotFoundError, FileExistsError) as e: + print(e) + pass + + def threaded_write(path_rows): + new_path, rows = path_rows + for sibling_path in glob_func(str(Path(new_path).with_name("**"))): + if not paths_are_equal(sibling_path, Path(new_path).parent): + _remove_file(sibling_path, file_system) + with file_system.open(new_path, mode="wb") as file: + table = _geopandas_to_arrow( + rows, + index=df.index, + schema_version=None, + ) + pq.write_table(table, file, compression="snappy", **kwargs) + + with ThreadPoolExecutor() as executor: + list(executor.map(threaded_write, args)) + + +def _read_partitioned_parquet( + path, filters, file_system, mask, pandas_fallback, threads, **kwargs +): + try: + glob_func = functools.partial(file_system.glob, detail=False) + except AttributeError: + glob_func = functools.partial(glob.glob, recursive=True) + + filters = filters or [] + new_filters = [] + for filt in filters: + if "in" in filt: + values = [ + x.strip("(") + .strip(")") + .strip("[") + .strip("]") + .strip("{") + .strip("}") + .strip(" ") + for x in filt[-1].split(",") + ] + filt = [filt[0] + "=" + x for x in values] + else: + filt = ["".join(filt)] + new_filters.append(filt) + + def intersects(file, mask) -> bool: + bbox, _ = _get_bounds_parquet_from_open_file(file, file_system) + return shapely.box(*bbox).intersects(to_shapely(mask)) + + def read(path) -> GeoDataFrame | None: + with file_system.open(path, "rb") as file: + if mask is not None and not intersects(file, mask): + return + + schema = kwargs.pop("schema", pq.read_schema(file)) + + return gpd.read_parquet(file, schema=schema, **kwargs) + + with ThreadPoolExecutor() as executor: + results = [ + x + for x in ( + executor.map( + read, + ( + path + for path in glob_func(str(Path(path) / "**/*.parquet")) + if all( + any(subfilt in Path(path).parts for subfilt in filt) + for filt in new_filters + ) + ), + ) + ) + if x is not None + ] + if results: + if mask is not None: + return sfilter(pd.concat(results), mask) + return pd.concat(results) + + # add columns to empty DataFrame + first_path = next(iter(glob_func(str(Path(path) / "**/*.parquet")))) + return gpd.GeoDataFrame( + columns=list(dict.fromkeys(_get_columns(first_path, file_system))) + ) + + +def paths_are_equal(path1: Path | str, path2: Path | str) -> bool: + return Path(path1).parts == Path(path2).parts + + +def has_partitions(path, file_system) -> bool: + try: + glob_func = functools.partial(file_system.glob, detail=False) + except AttributeError: + glob_func = functools.partial(glob.glob, recursive=True) + + return bool( + [ + x + for x in glob_func(str(Path(path) / "**/*.parquet")) + if not 
+        ]
+    )
 
 
 def check_files(
@@ -419,7 +588,7 @@
         within_minutes: Optionally include only files that were updated in the
             last n minutes.
     """
-    file_system = dp.FileClient.get_gcs_file_system()
+    file_system = config["file_system"]()
 
     # (recursive doesn't work, so doing recursive search below)
     info = file_system.ls(folder, detail=True, recursive=True)
@@ -474,7 +643,7 @@
 
 
 def _get_files_in_subfolders(folderinfo: list[dict]) -> list[tuple]:
-    file_system = dp.FileClient.get_gcs_file_system()
+    file_system = config["file_system"]()
 
     fileinfo = []
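
Usage sketch (reviewer note, not part of the diff): a minimal example of how the new config-based file system and the partitioned parquet support might be used once this change is applied. The LocalFileSystem override, paths, column names and filter values are made up for illustration, and it is assumed that any fsspec-style file system exposing open/exists/glob/ls can stand in for GCSFileSystem.

import geopandas as gpd
from fsspec.implementations.local import LocalFileSystem  # assumed stand-in for GCSFileSystem
from shapely.geometry import Point, box

import sgis as sg

# Swap the default GCSFileSystem class for a local fsspec file system, e.g. in tests.
sg.config["file_system"] = LocalFileSystem

df = gpd.GeoDataFrame(
    {
        "year": [2024, 2024, 2025],
        "region": ["north", "north", "south"],
        "geometry": [Point(0, 0), Point(1, 1), Point(10, 10)],
    },
    crs=25833,
)

# With partition_cols, write_geopandas writes one hive-style directory per group
# under the ".parquet" root, e.g. points.parquet/year=2024/region=north/<uuid>.parquet
sg.write_geopandas(df, "/tmp/points.parquet", partition_cols=["year", "region"])

# read_geopandas detects the partitioned layout, prunes partition directories with
# filters (matched as strings against the hive directory names, per the helper above),
# prunes files by their bbox metadata with mask, and concatenates the remaining parts.
north_2024 = sg.read_geopandas(
    "/tmp/points.parquet",
    filters=[("year", "=", "2024")],
    mask=box(-1, -1, 2, 2),
)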