From b44a50cb9d02ff6a5335748300bef2b19d8a52fe Mon Sep 17 00:00:00 2001 From: Konstantinos Tsitsimpikos Date: Tue, 16 Nov 2021 17:29:07 +0200 Subject: [PATCH] [Fix/Enhancement] Adding default argument current time timestamp in open 'w' mode (#89) * Adding in every save operation current time timestamp --- tiledb/ml/models/base.py | 5 +++++ tiledb/ml/models/pytorch.py | 10 ++++++++-- tiledb/ml/models/sklearn.py | 10 ++++++++-- tiledb/ml/models/tensorflow_keras.py | 9 +++++++-- 4 files changed, 28 insertions(+), 6 deletions(-) diff --git a/tiledb/ml/models/base.py b/tiledb/ml/models/base.py index 5a251315..82ccddc2 100644 --- a/tiledb/ml/models/base.py +++ b/tiledb/ml/models/base.py @@ -2,6 +2,7 @@ import os import platform +import time from abc import ABC, abstractmethod from enum import Enum, unique from typing import Any, Generic, Mapping, Optional, Tuple, TypeVar @@ -26,6 +27,10 @@ class ModelFileProperties(Enum): TILEDB_ML_MODEL_PREVIEW = "TILEDB_ML_MODEL_PREVIEW" +def current_milli_time() -> int: + return round(time.time() * 1000) + + class TileDBModel(ABC, Generic[Model]): """ This is the base class for all TileDB model storage functionalities, i.e, diff --git a/tiledb/ml/models/pytorch.py b/tiledb/ml/models/pytorch.py index e1be85bd..8da3daef 100644 --- a/tiledb/ml/models/pytorch.py +++ b/tiledb/ml/models/pytorch.py @@ -9,7 +9,7 @@ import tiledb -from .base import Meta, TileDBModel, Timestamp +from .base import Meta, TileDBModel, Timestamp, current_milli_time class PyTorchTileDBModel(TileDBModel[torch.nn.Module]): @@ -102,6 +102,8 @@ def load( # type: ignore :param optimizer: A defined PyTorch optimizer. :return: A dictionary with attributes other than model or optimizer state_dict. """ + + # TODO: Change timestamp when issue in core is resolved model_array = tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) model_array_results = model_array[:] schema = model_array.schema @@ -211,7 +213,11 @@ def _write_array( optimizer state, extra model information) of a PyTorch model. :param meta: Extra metadata to save in a TileDB array. """ - with tiledb.open(self.uri, "w", ctx=self.ctx) as tf_model_tiledb: + + # TODO: Change timestamp when issue in core is resolved + with tiledb.open( + self.uri, "w", timestamp=current_milli_time(), ctx=self.ctx + ) as tf_model_tiledb: # Insertion in TileDB array tf_model_tiledb[:] = { key: np.array([value]) for key, value in serialized_model_dict.items() diff --git a/tiledb/ml/models/sklearn.py b/tiledb/ml/models/sklearn.py index 86e55d72..0fc8e527 100644 --- a/tiledb/ml/models/sklearn.py +++ b/tiledb/ml/models/sklearn.py @@ -10,7 +10,7 @@ import tiledb -from .base import Meta, TileDBModel, Timestamp +from .base import Meta, TileDBModel, Timestamp, current_milli_time class SklearnTileDBModel(TileDBModel[BaseEstimator]): @@ -47,6 +47,8 @@ def load(self, *, timestamp: Optional[Timestamp] = None) -> BaseEstimator: in the specified time range. :return: A Sklearn model object. """ + # TODO: Change timestamp when issue in core is resolved + model_array = tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) model_array_results = model_array[:] model = pickle.loads(model_array_results["model_params"].item(0)) @@ -102,7 +104,11 @@ def _write_array(self, serialized_model: bytes, meta: Optional[Meta]) -> None: :param serialized_model: A pickled sklearn model. :param meta: Extra metadata to save in a TileDB array. """ - with tiledb.open(self.uri, "w", ctx=self.ctx) as tf_model_tiledb: + # TODO: Change timestamp when issue in core is resolved + + with tiledb.open( + self.uri, "w", timestamp=current_milli_time(), ctx=self.ctx + ) as tf_model_tiledb: # Insertion in TileDB array tf_model_tiledb[:] = {"model_params": np.array([serialized_model])} self.update_model_metadata(array=tf_model_tiledb, meta=meta) diff --git a/tiledb/ml/models/tensorflow_keras.py b/tiledb/ml/models/tensorflow_keras.py index 57cc9ebe..65b46069 100644 --- a/tiledb/ml/models/tensorflow_keras.py +++ b/tiledb/ml/models/tensorflow_keras.py @@ -20,7 +20,7 @@ import tiledb -from .base import Meta, TileDBModel, Timestamp +from .base import Meta, TileDBModel, Timestamp, current_milli_time class TensorflowKerasTileDBModel(TileDBModel[tf.keras.Model]): @@ -88,6 +88,8 @@ def load( :param input_shape: The shape that the custom model expects as input :return: Tensorflow model. """ + # TODO: Change timestamp when issue in core is resolved + with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array: model_array_results = model_array[:] model_config = json.loads(model_array.meta["model_config"]) @@ -257,7 +259,10 @@ def _write_array( ) -> None: """Write Tensorflow model to a TileDB array.""" assert self.model - with tiledb.open(self.uri, "w", ctx=self.ctx) as tf_model_tiledb: + # TODO: Change timestamp when issue in core is resolved + with tiledb.open( + self.uri, "w", timestamp=current_milli_time(), ctx=self.ctx + ) as tf_model_tiledb: if isinstance(self.model, (Functional, Sequential)): tf_model_tiledb[:] = { "model_weights": np.array([serialized_weights]),