Skip to content

Commit

Permalink
[Fix/Enhancement] Adding default argument current time timestamp in o…
Browse files Browse the repository at this point in the history
…pen 'w' mode (#89)

* Adding in every save operation current time timestamp
ktsitsi authored Nov 16, 2021

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent fc781c9 commit b44a50c
Showing 4 changed files with 28 additions and 6 deletions.
5 changes: 5 additions & 0 deletions tiledb/ml/models/base.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@

import os
import platform
import time
from abc import ABC, abstractmethod
from enum import Enum, unique
from typing import Any, Generic, Mapping, Optional, Tuple, TypeVar
@@ -26,6 +27,10 @@ class ModelFileProperties(Enum):
TILEDB_ML_MODEL_PREVIEW = "TILEDB_ML_MODEL_PREVIEW"


def current_milli_time() -> int:
return round(time.time() * 1000)


class TileDBModel(ABC, Generic[Model]):
"""
This is the base class for all TileDB model storage functionalities, i.e,
10 changes: 8 additions & 2 deletions tiledb/ml/models/pytorch.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@

import tiledb

from .base import Meta, TileDBModel, Timestamp
from .base import Meta, TileDBModel, Timestamp, current_milli_time


class PyTorchTileDBModel(TileDBModel[torch.nn.Module]):
@@ -102,6 +102,8 @@ def load( # type: ignore
:param optimizer: A defined PyTorch optimizer.
:return: A dictionary with attributes other than model or optimizer state_dict.
"""

# TODO: Change timestamp when issue in core is resolved
model_array = tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp)
model_array_results = model_array[:]
schema = model_array.schema
@@ -211,7 +213,11 @@ def _write_array(
optimizer state, extra model information) of a PyTorch model.
:param meta: Extra metadata to save in a TileDB array.
"""
with tiledb.open(self.uri, "w", ctx=self.ctx) as tf_model_tiledb:

# TODO: Change timestamp when issue in core is resolved
with tiledb.open(
self.uri, "w", timestamp=current_milli_time(), ctx=self.ctx
) as tf_model_tiledb:
# Insertion in TileDB array
tf_model_tiledb[:] = {
key: np.array([value]) for key, value in serialized_model_dict.items()
10 changes: 8 additions & 2 deletions tiledb/ml/models/sklearn.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@

import tiledb

from .base import Meta, TileDBModel, Timestamp
from .base import Meta, TileDBModel, Timestamp, current_milli_time


class SklearnTileDBModel(TileDBModel[BaseEstimator]):
@@ -47,6 +47,8 @@ def load(self, *, timestamp: Optional[Timestamp] = None) -> BaseEstimator:
in the specified time range.
:return: A Sklearn model object.
"""
# TODO: Change timestamp when issue in core is resolved

model_array = tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp)
model_array_results = model_array[:]
model = pickle.loads(model_array_results["model_params"].item(0))
@@ -102,7 +104,11 @@ def _write_array(self, serialized_model: bytes, meta: Optional[Meta]) -> None:
:param serialized_model: A pickled sklearn model.
:param meta: Extra metadata to save in a TileDB array.
"""
with tiledb.open(self.uri, "w", ctx=self.ctx) as tf_model_tiledb:
# TODO: Change timestamp when issue in core is resolved

with tiledb.open(
self.uri, "w", timestamp=current_milli_time(), ctx=self.ctx
) as tf_model_tiledb:
# Insertion in TileDB array
tf_model_tiledb[:] = {"model_params": np.array([serialized_model])}
self.update_model_metadata(array=tf_model_tiledb, meta=meta)
9 changes: 7 additions & 2 deletions tiledb/ml/models/tensorflow_keras.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@

import tiledb

from .base import Meta, TileDBModel, Timestamp
from .base import Meta, TileDBModel, Timestamp, current_milli_time


class TensorflowKerasTileDBModel(TileDBModel[tf.keras.Model]):
@@ -88,6 +88,8 @@ def load(
:param input_shape: The shape that the custom model expects as input
:return: Tensorflow model.
"""
# TODO: Change timestamp when issue in core is resolved

with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
model_array_results = model_array[:]
model_config = json.loads(model_array.meta["model_config"])
@@ -257,7 +259,10 @@ def _write_array(
) -> None:
"""Write Tensorflow model to a TileDB array."""
assert self.model
with tiledb.open(self.uri, "w", ctx=self.ctx) as tf_model_tiledb:
# TODO: Change timestamp when issue in core is resolved
with tiledb.open(
self.uri, "w", timestamp=current_milli_time(), ctx=self.ctx
) as tf_model_tiledb:
if isinstance(self.model, (Functional, Sequential)):
tf_model_tiledb[:] = {
"model_weights": np.array([serialized_weights]),

0 comments on commit b44a50c

Please sign in to comment.