Refactor storage #133

Merged · 8 commits · May 13, 2021
144 changes: 34 additions & 110 deletions pangeo_forge_recipes/recipes/xarray_zarr.py
@@ -2,24 +2,20 @@
A Pangeo Forge Recipe
"""

import json
import logging
import os
import tempfile
import warnings
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass, field
from itertools import product
from typing import Callable, Dict, List, Optional, Sequence, Tuple

import dask
import fsspec
import numpy as np
import xarray as xr
import zarr

from ..patterns import FilePattern
from ..storage import AbstractTarget, UninitializedTarget, UninitializedTargetError
from ..storage import AbstractTarget, CacheFSSpecTarget, MetadataTarget, file_opener
from ..utils import (
chunk_bounds_and_conflicts,
chunked_iterable,
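This import swap is the crux of the refactor: the ad-hoc copy helpers deleted below now live behind `CacheFSSpecTarget`, `MetadataTarget`, and `file_opener` in `pangeo_forge_recipes.storage`. A minimal sketch of wiring the new targets into a recipe; the local paths are illustrative, and the `fs`/`root_path` constructor arguments are an assumption based on the fsspec-backed naming:

```python
from fsspec.implementations.local import LocalFileSystem

from pangeo_forge_recipes.storage import CacheFSSpecTarget, MetadataTarget

fs = LocalFileSystem()

# Assumed constructor: an fsspec filesystem plus a root path.
# Paths are illustrative only.
input_cache = CacheFSSpecTarget(fs=fs, root_path="/tmp/pangeo-forge/cache")
metadata_cache = MetadataTarget(fs=fs, root_path="/tmp/pangeo-forge/meta")
```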
@@ -47,59 +43,12 @@ def _chunk_metadata_fname(chunk_key) -> str:
return "chunk-meta-" + _encode_key(chunk_key) + ".json"


def _copy_btw_filesystems(input_opener, output_opener, BLOCK_SIZE=10_000_000):
with input_opener as source:
with output_opener as target:
while True:
data = source.read(BLOCK_SIZE)
if not data:
break
target.write(data)


@contextmanager
def _maybe_open_or_copy_to_local(opener, copy_to_local, orig_name):
_, suffix = os.path.splitext(orig_name)
if copy_to_local:
ntf = tempfile.NamedTemporaryFile(suffix=suffix)
tmp_name = ntf.name
logger.info(f"Copying {orig_name} to local file {tmp_name}")
target_opener = open(tmp_name, mode="wb")
_copy_btw_filesystems(opener, target_opener)
yield tmp_name
ntf.close() # cleans up the temporary file
else:
with opener as fp:
with fp as fp2:
yield fp2


@contextmanager
def _fsspec_safe_open(fname, **kwargs):
# workaround for inconsistent behavior of fsspec.open
# https://github.com/intake/filesystem_spec/issues/579
with fsspec.open(fname, **kwargs) as fp:
with fp as fp2:
yield fp2


def _get_url_size(fname):
with fsspec.open(fname, mode="rb") as of:
size = of.size
return size

ChunkKey = Tuple[int]
InputKey = Tuple[int]

# Notes about dataclasses:
# - https://www.python.org/dev/peps/pep-0557/#inheritance
# - https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
# The main awkward thing here is that, because we are using multiple inheritance
# with dataclasses, _ALL_ fields must have default values. This makes it impossible
# to have "required" keyword arguments--everything needs some default.
# That's why we end up with `UninitializedTarget` and `_variable_sequence_pattern_default_factory`


ChunkKey = Tuple[int]
InputKey = Tuple[int]


@dataclass
@@ -140,9 +89,9 @@ class XarrayZarrRecipe(BaseRecipe):
file_pattern: FilePattern
inputs_per_chunk: Optional[int] = 1
target_chunks: Dict[str, int] = field(default_factory=dict)
target: AbstractTarget = field(default_factory=UninitializedTarget)
input_cache: AbstractTarget = field(default_factory=UninitializedTarget)
metadata_cache: AbstractTarget = field(default_factory=UninitializedTarget)
target: Optional[AbstractTarget] = None
input_cache: Optional[CacheFSSpecTarget] = None
metadata_cache: Optional[MetadataTarget] = None
cache_inputs: bool = True
copy_input_to_local_file: bool = False
consolidate_zarr: bool = True
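The removed comment block above explained the constraint driving the old design: once any ancestor field in a dataclass hierarchy has a default, every later field needs one too, which is what the `UninitializedTarget` sentinel papered over. The new code uses plain `Optional[...] = None` defaults plus explicit runtime checks. A toy sketch of the pattern (class and field names are illustrative, not from the codebase):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Base:
    # A default here forces every field in subclasses to have one as well,
    # so "required" targets can't be expressed as positional arguments.
    name: str = "base"


@dataclass
class Recipe(Base):
    # None is the sentinel; methods that need the target check at call time.
    target: Optional[str] = None

    def store(self) -> None:
        if self.target is None:
            raise ValueError("target is not set.")
        print(f"storing to {self.target}")
```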
@@ -247,6 +196,8 @@ def _set_target_chunks(self):
@property # type: ignore
@closure
def prepare_target(self) -> None:
if self.target is None:
raise ValueError("target is not set.")
try:
ds = self.open_target()
logger.info("Found an existing dataset in target")
@@ -307,38 +258,31 @@ def prepare_target(self) -> None:
self.expand_target_dim(self._concat_dim, n_sequence)

if self._cache_metadata:
if self.metadata_cache is None:
raise ValueError("metadata_cache is not set")
# if nitems_per_input is not constant, we need to cache this info
recipe_meta = {"input_sequence_lens": input_sequence_lens}
meta_mapper = self.metadata_cache.get_mapper()
# we are saving a dictionary with one key (input_sequence_lens)
logger.info("Caching global metadata")
meta_mapper[_GLOBAL_METADATA_KEY] = json.dumps(recipe_meta).encode("utf-8")
self.metadata_cache[_GLOBAL_METADATA_KEY] = recipe_meta

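The JSON round-trip has moved out of the recipe: `MetadataTarget` now accepts a plain dict and handles encoding itself. A plausible reading of the new mapping behavior, inferred from the before/after lines in this hunk (the class body is a sketch, not the actual storage.py code):

```python
import json


class MetadataTargetSketch:
    """Illustrative only: a dict-like target that JSON-encodes values."""

    def __init__(self, mapper):
        self._mapper = mapper  # e.g. the result of fs.get_mapper(root_path)

    def __setitem__(self, key: str, value: dict) -> None:
        self._mapper[key] = json.dumps(value).encode("utf-8")

    def __getitem__(self, key: str) -> dict:
        return json.loads(self._mapper[key])
```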
# TODO: figure out how to make mypy happy with this convoluted structure
@property # type: ignore
@closure
def cache_input(self, input_key: InputKey) -> None: # type: ignore
logger.info(f"Caching input {input_key}")
fname = self.file_pattern[input_key]
if self.cache_inputs:
if self.input_cache is None:
raise ValueError("input_cache is not set.")
logger.info(f"Caching input '{input_key}'")
fname = self.file_pattern[input_key]
self.input_cache.cache_file(fname, **self.fsspec_open_kwargs)

# check and see if the file already exists in the cache
if self.input_cache.exists(fname):
cached_size = self.input_cache.size(fname)
remote_size = _get_url_size(fname)
if cached_size == remote_size:
logger.info(f"Input {input_key} file {fname} is already cached")
return

input_opener = _fsspec_safe_open(fname, mode="rb", **self.fsspec_open_kwargs)
target_opener = self.input_cache.open(fname, mode="wb")
_copy_btw_filesystems(input_opener, target_opener)
# TODO: make it so we can cache metadata WITHOUT copying the file
if self._cache_metadata:
self.cache_input_metadata(input_key)

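The exists/size check and the blockwise copy deleted above are presumably what `cache_file` now does internally. A hedged reconstruction, assuming the helpers removed in this diff (`_get_url_size`, `_copy_btw_filesystems`) moved to storage.py unchanged:

```python
import fsspec


def cache_file(cache, fname: str, **open_kwargs) -> None:
    # Skip the copy when the cached object already matches the remote size.
    if cache.exists(fname) and cache.size(fname) == _get_url_size(fname):
        return
    input_opener = fsspec.open(fname, mode="rb", **open_kwargs)
    target_opener = cache.open(fname, mode="wb")
    _copy_btw_filesystems(input_opener, target_opener)
```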
@property # type: ignore
@closure
def store_chunk(self, chunk_key: ChunkKey) -> None: # type: ignore
if self.target is None:
raise ValueError("target has not been set.")
with self.open_chunk(chunk_key) as ds_chunk:
# writing a region means that all the variables MUST have concat_dim
to_drop = [v for v in ds_chunk.variables if self._concat_dim not in ds_chunk[v].dims]
@@ -377,42 +321,20 @@ def store_chunk(self, chunk_key: ChunkKey) -> None: # type: ignore
@property # type: ignore
@closure
def finalize_target(self) -> None:
if self.target is None:
raise ValueError("target has not been set.")
if self.consolidate_zarr:
logger.info("Consolidating Zarr metadata")
target_mapper = self.target.get_mapper()
zarr.consolidate_metadata(target_mapper)

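Consolidation collapses the store's per-array metadata into a single `.zmetadata` key so readers can open the dataset with one request instead of many. A minimal usage sketch (the store path is illustrative):

```python
import xarray as xr
import zarr

zarr.consolidate_metadata("target.zarr")  # writes a combined .zmetadata key
ds = xr.open_zarr("target.zarr", consolidated=True)  # one metadata read
```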
@contextmanager
def input_opener(self, fname: str):
try:
logger.info(f"Opening '{fname}' from cache")
opener = self.input_cache.open(fname, mode="rb")
with _maybe_open_or_copy_to_local(opener, self.copy_input_to_local_file, fname) as fp:
yield fp
except (IOError, FileNotFoundError, UninitializedTargetError) as err:
if self.cache_inputs:
raise Exception(
f"You are trying to open input {fname}, but the file is "
"not cached yet. First call `cache_input` or set "
"`cache_inputs=False`."
) from err
logger.info(f"No cache found. Opening input `{fname}` directly.")
opener = _fsspec_safe_open(fname, mode="rb", **self.fsspec_open_kwargs)
with _maybe_open_or_copy_to_local(opener, self.copy_input_to_local_file, fname) as fp:
yield fp

@contextmanager
def open_input(self, input_key: InputKey):
fname = self.file_pattern[input_key]
logger.info(f"Opening input with Xarray {input_key}: '{fname}'")
with self.input_opener(fname) as f:
cache = self.input_cache if self.cache_inputs else None
with file_opener(fname, cache=cache, copy_to_local=self.copy_input_to_local_file) as f:
ds = xr.open_dataset(f, **self.xarray_open_kwargs)
# Explicitly load into memory;
# if we don't do this, we get a ValueError: seek of closed file.
# But there will be some cases where we really don't want to load.
# how to keep around the open file object?
# ds = ds.load()

ds = fix_scalar_attr_encoding(ds)

if self.delete_input_encoding:
@@ -426,11 +348,12 @@ def open_input(self, input_key: InputKey):
yield ds

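The cache-or-direct branching and the copy-to-local handling deleted above are now folded into a single `file_opener` context manager. Its presumed contract, sketched from how it is called here (this is an illustration, not the real storage.py implementation):

```python
import fsspec
from contextlib import contextmanager


@contextmanager
def file_opener_sketch(fname, cache=None, copy_to_local=False):
    # Illustrative contract only: read from the cache when one is given,
    # otherwise open the source URL directly. The real helper presumably
    # also handles copy_to_local by yielding a local temp-file path.
    opener = cache.open(fname, mode="rb") if cache else fsspec.open(fname, mode="rb")
    with opener as f:
        yield f
```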
def cache_input_metadata(self, input_key: InputKey):
if self.metadata_cache is None:
raise ValueError("metadata_cache is not set.")
logger.info(f"Caching metadata for input '{input_key}'")
with self.open_input(input_key) as ds:
metadata = ds.to_dict(data=False)
mapper = self.metadata_cache.get_mapper()
mapper[_input_metadata_fname(input_key)] = json.dumps(metadata).encode("utf-8")
input_metadata = ds.to_dict(data=False)
self.metadata_cache[_input_metadata_fname(input_key)] = input_metadata

@contextmanager
def open_chunk(self, chunk_key: ChunkKey):
@@ -498,9 +421,10 @@ def region_and_conflicts_for_chunk(self, chunk_key: ChunkKey):
self._concat_dim # type: ignore
]
else:
input_sequence_lens = json.loads(
self.metadata_cache.get_mapper()[_GLOBAL_METADATA_KEY]
)["input_sequence_lens"]
if self.metadata_cache is None:
raise ValueError("metadata_cache is not set.")
global_metadata = self.metadata_cache[_GLOBAL_METADATA_KEY]
input_sequence_lens = global_metadata["input_sequence_lens"]

chunk_bounds, all_chunk_conflicts = chunk_bounds_and_conflicts(
input_sequence_lens, self._concat_dim_chunks # type: ignore
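A concrete illustration of why conflicts matter here: when target chunk boundaries fall mid-input, neighboring inputs write into the same Zarr chunk and their writes must be serialized. A toy walk-through (the exact return types of `chunk_bounds_and_conflicts` are an assumption):

```python
# Worked example: three inputs of length 2 along the concat dim,
# target chunk length 3.
#
#   element index:  0 1 | 2 3 | 4 5    <- input boundaries
#   zarr chunks:    0 1 2 | 3 4 5      <- chunk boundaries
#
# Input 0 writes only into zarr chunk 0; input 1 straddles chunks 0 and 1;
# input 2 writes only into chunk 1. Inputs 0/1 therefore conflict on chunk 0
# and inputs 1/2 conflict on chunk 1 -- those writes must be guarded with
# per-chunk locks, which is what the returned conflicts encode.
```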
@@ -523,10 +447,10 @@ def iter_chunks(self):
yield k

def get_input_meta(self, *input_keys: Sequence[InputKey]) -> Dict:
meta_mapper = self.metadata_cache.get_mapper()
# getitems should be async; much faster than serial calls
all_meta_raw = meta_mapper.getitems([_input_metadata_fname(k) for k in input_keys])
return {k: json.loads(raw_bytes) for k, raw_bytes in all_meta_raw.items()}
if self.metadata_cache is None:
raise ValueError("metadata_cache is not set.")
return self.metadata_cache.getitems([_input_metadata_fname(k) for k in input_keys])

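The old inline comment is worth keeping in mind: `getitems` batches the key fetches (concurrently on async-capable filesystems) instead of issuing one request per input. Presumed usage, with the return shape an assumption inferred from the old `json.loads` loop:

```python
# Batched fetch: one getitems call instead of a per-key loop.
keys = [_input_metadata_fname(k) for k in input_keys]
all_meta = metadata_cache.getitems(keys)  # assumed: {key: decoded dict}
```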
def input_position(self, input_key):
# returns the index position of an input key wrt the concat_dim