Skip to content

Commit

Permalink
Moderate model refactoring (#199)
Browse files Browse the repository at this point in the history
* Inline _get_file_properties

* TileDBArtifact: Merge _write_model_metadata into _write_array

* TensorflowKerasTileDBModel: pass model_metadata to the _write_array() call

* Pass tensorboard_log_dir to _write_array

* Refactor _load_tensorboard to use _get_model_param

* _get_model_param: Fetch only the key attribute

* Open/close the tiledb array only once per load call
  • Loading branch information
gsakkis authored Dec 20, 2022
1 parent bf5644e commit 92f594d
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 313 deletions.
145 changes: 59 additions & 86 deletions tiledb/ml/models/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,14 @@ def __init__(
self.ctx = ctx
self.artifact = artifact
self.uri = get_cloud_uri(uri, namespace) if namespace else uri
self._file_properties = self._get_file_properties()
self._file_properties = {
ModelFileProperties.TILEDB_ML_MODEL_ML_FRAMEWORK.value: self.Name,
ModelFileProperties.TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION.value: self.Version,
ModelFileProperties.TILEDB_ML_MODEL_STAGE.value: "STAGING",
ModelFileProperties.TILEDB_ML_MODEL_PYTHON_VERSION.value: platform.python_version(),
ModelFileProperties.TILEDB_ML_MODEL_PREVIEW.value: self.preview(),
ModelFileProperties.TILEDB_ML_MODEL_VERSION.value: __version__,
}

@abstractmethod
def save(self, *, update: bool = False, meta: Optional[Meta] = None) -> None:
Expand All @@ -88,34 +95,23 @@ def get_weights(self, timestamp: Optional[Timestamp] = None) -> Weights:
"""
Returns model's weights. Works for Tensorflow Keras and PyTorch
"""
return cast(Weights, self._get_model_param("model", timestamp))
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
return cast(Weights, self._get_model_param(model_array, "model"))

def get_optimizer_weights(self, timestamp: Optional[Timestamp] = None) -> Weights:
"""
Returns optimizer's weights. Works for Tensorflow Keras and PyTorch
"""
return cast(Weights, self._get_model_param("optimizer", timestamp))
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
return cast(Weights, self._get_model_param(model_array, "optimizer"))

@abstractmethod
def preview(self) -> str:
    """
    Creates a string representation of a machine learning model.

    Implemented by each framework-specific subclass; the returned string is
    also stored in the array's file properties under
    TILEDB_ML_MODEL_PREVIEW (see the _file_properties mapping in __init__).
    """

def _get_file_properties(self) -> Mapping[str, str]:
return {
ModelFileProperties.TILEDB_ML_MODEL_ML_FRAMEWORK.value: self.Name,
ModelFileProperties.TILEDB_ML_MODEL_ML_FRAMEWORK_VERSION.value: self.Version,
ModelFileProperties.TILEDB_ML_MODEL_STAGE.value: "STAGING",
ModelFileProperties.TILEDB_ML_MODEL_PYTHON_VERSION.value: platform.python_version(),
ModelFileProperties.TILEDB_ML_MODEL_PREVIEW.value: self.preview(),
ModelFileProperties.TILEDB_ML_MODEL_VERSION.value: __version__,
}

def _create_array(
self,
fields: Sequence[str],
) -> None:
def _create_array(self, fields: Sequence[str]) -> None:
"""Internal method that creates a TileDB array based on the model's spec."""

# The array will be 1 dimensional with domain of 0 to max uint64. We use a tile extent of 1024 bytes
Expand Down Expand Up @@ -152,101 +148,78 @@ def _create_array(
if self.namespace:
update_file_properties(self.uri, self._file_properties)

def _write_array(self, model_params: Mapping[str, bytes]) -> None:
"""
Writes machine learning model related data, i.e., model weights, optimizer weights and Tensorboard files, to
a dense TileDB array.
"""
def _write_array(
self,
model_params: Mapping[str, bytes],
tensorboard_log_dir: Optional[str] = None,
meta: Optional[Meta] = None,
) -> None:
if tensorboard_log_dir:
tensorboard = self._serialize_tensorboard(tensorboard_log_dir)
else:
tensorboard = b""
model_params = dict(tensorboard=tensorboard, **model_params)

if meta is None:
meta = {}
if not meta.keys().isdisjoint(self._file_properties.keys()):
raise ValueError(
"Please avoid using file property key names as metadata keys!"
)

with tiledb.open(self.uri, "w", ctx=self.ctx) as model_array:
one_d_buffers = {}
max_len = 0

for key, value in model_params.items():
one_d_buffer = np.frombuffer(value, dtype=np.uint8)
one_d_buffer_len = len(one_d_buffer)
one_d_buffers[key] = one_d_buffer

# Write size only in case is greater than 0.
if one_d_buffer_len:
model_array.meta[key + "_size"] = one_d_buffer_len

if one_d_buffer_len > max_len:
max_len = one_d_buffer_len

model_array[0:max_len] = {
key: np.pad(value, (0, max_len - len(value)))
for key, value in one_d_buffers.items()
}

def _write_model_metadata(self, meta: Meta) -> None:
"""
Update the metadata in a TileDB model array. File properties also go in the metadata section.
:param meta: A mapping with the <key, value> pairs to be inserted in array's metadata.
"""
with tiledb.open(self.uri, "w", ctx=self.ctx) as model_array:
# Raise ValueError in case users provide metadata with the same keys as file properties.
if not meta.keys().isdisjoint(self._file_properties.keys()):
raise ValueError(
"Please avoid using file property key names as metadata keys!"
)

for key, value in meta.items():
model_array.meta[key] = value

for key, value in self._file_properties.items():
model_array.meta[key] = value
for mapping in meta, self._file_properties:
for key, value in mapping.items():
model_array.meta[key] = value

def _get_model_param(self, model_array: tiledb.Array, key: str) -> Any:
    """
    Read and unpickle a single stored model parameter (e.g. "model",
    "optimizer", "tensorboard") from an already-open TileDB model array.

    :param model_array: Open TileDB array holding the serialized model data.
    :param key: Name of the array attribute to read; its serialized byte
        length is stored in the array metadata under "<key>_size".
    :return: The unpickled parameter value.
    :raises Exception: If the "<key>_size" metadata entry is missing.
    """
    size_key = key + "_size"
    try:
        size = model_array.meta[size_key]
    except KeyError as e:
        # Chain the original KeyError so the root cause stays visible.
        raise Exception(
            f"{size_key} metadata entry not present in {self.uri}"
            f" (existing keys: {set(model_array.meta.keys())})"
        ) from e
    # Fetch only the requested attribute; the payload occupies bytes [0, size).
    # NOTE(review): pickle.loads on stored bytes — safe only for trusted arrays.
    return pickle.loads(model_array.query(attrs=(key,))[0:size][key].tobytes())

@staticmethod
def _serialize_tensorboard_files(log_dir: str) -> bytes:
def _serialize_tensorboard(log_dir: str) -> bytes:
"""Serialize all Tensorboard files."""

if not os.path.exists(log_dir):
raise ValueError(f"{log_dir} does not exist")

event_files = {}
tensorboard_files = {}
for path in glob.glob(f"{log_dir}/*tfevents*"):
with open(path, "rb") as f:
event_files[path] = f.read()
tensorboard_files[path] = f.read()
return pickle.dumps(tensorboard_files, protocol=4)

return pickle.dumps(event_files, protocol=4)

def _get_model_param(self, key: str, timestamp: Optional[Timestamp]) -> Any:
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
size_key = key + "_size"
try:
size = model_array.meta[size_key]
except KeyError:
raise Exception(
f"{size_key} metadata entry not present in {self.uri}"
f" (existing keys: {set(model_array.meta.keys())})"
)
return pickle.loads(model_array[0:size][key].tobytes())

def _load_tensorboard(self, timestamp: Optional[Timestamp] = None) -> None:
def _load_tensorboard(self, model_array: tiledb.Array) -> None:
"""
Writes Tensorboard files to directory. Works for Tensorflow-Keras and PyTorch.
Write Tensorboard files to directory. Works for Tensorflow-Keras and PyTorch.
"""
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
try:
tensorboard_size = model_array.meta["tensorboard_size"]
except KeyError:
raise Exception(
f"tensorboard_size metadata entry not present in"
f" (existing keys: {set(model_array.meta.keys())})"
)

tb_contents = model_array[0:tensorboard_size]["tensorboard"]
tensorboard_files = pickle.loads(tb_contents.tobytes())

for path, file_bytes in tensorboard_files.items():
log_dir = os.path.dirname(path)
if not os.path.exists(log_dir):
os.mkdir(log_dir)
with open(os.path.join(log_dir, os.path.basename(path)), "wb") as f:
f.write(file_bytes)

def _use_legacy_schema(self, timestamp: Optional[Timestamp]) -> bool:
tensorboard_files = self._get_model_param(model_array, "tensorboard")
for path, file_bytes in tensorboard_files.items():
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as f:
f.write(file_bytes)

def _use_legacy_schema(self, model_array: tiledb.Array) -> bool:
# TODO: Decide based on tiledb-ml version and not on schema characteristics, like "offset".
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
return str(model_array.schema.domain.dim(0).name) != "offset"
return str(model_array.schema.domain.dim(0).name) != "offset"
55 changes: 13 additions & 42 deletions tiledb/ml/models/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,6 @@ def save(
else:
serialized_optimizer_dict = b""

# Serialize Tensorboard files
if summary_writer:
tensorboard = self._serialize_tensorboard_files(
log_dir=summary_writer.log_dir
)
else:
tensorboard = b""

# Create TileDB model array
if not update:
self._create_array(fields=["model", "optimizer", "tensorboard"])
Expand All @@ -78,13 +70,11 @@ def save(
model_params={
"model": serialized_model_dict,
"optimizer": serialized_optimizer_dict,
"tensorboard": tensorboard,
}
},
tensorboard_log_dir=summary_writer.log_dir if summary_writer else None,
meta=meta,
)

if meta:
self._write_model_metadata(meta=meta)

def load(
self,
*,
Expand All @@ -102,29 +92,19 @@ def load(
:param callback: Boolean variable if True will store Callback data into saved directory
:return: A dictionary with attributes other than model or optimizer state_dict.
"""

load = (
self.__load_legacy
if self._use_legacy_schema(timestamp=timestamp)
else self.__load
)
return load(
model=model, optimizer=optimizer, timestamp=timestamp, callback=callback
)
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
if self._use_legacy_schema(model_array):
return self.__load_legacy(model_array, model, optimizer, callback)
else:
return self.__load(model_array, model, optimizer, callback)

def __load_legacy(
self,
model_array: tiledb.Array,
model: torch.nn.Module,
optimizer: Optimizer,
timestamp: Optional[Timestamp],
callback: bool,
) -> Optional[Mapping[str, Any]]:
"""
Load a PyTorch model from a TileDB array.
"""

# TODO: Change timestamp when issue in core is resolved
model_array = tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp)
model_array_results = model_array[:]
schema = model_array.schema

Expand Down Expand Up @@ -169,25 +149,16 @@ def __load_legacy(

def __load(
self,
model_array: tiledb.Array,
model: torch.nn.Module,
optimizer: Optimizer,
timestamp: Optional[Timestamp],
callback: bool,
) -> None:
"""
Load a PyTorch model from a TileDB array.
"""

model_state_dict = self.get_weights(timestamp=timestamp)
model.load_state_dict(model_state_dict)

# Load model's state dictionary
model.load_state_dict(self._get_model_param(model_array, "model"))
if optimizer:
opt_state_dict = self.get_optimizer_weights(timestamp=timestamp)
optimizer.load_state_dict(opt_state_dict)

optimizer.load_state_dict(self._get_model_param(model_array, "optimizer"))
if callback:
self._load_tensorboard(timestamp=timestamp)
self._load_tensorboard(model_array)

def preview(self) -> str:
"""
Expand Down
48 changes: 10 additions & 38 deletions tiledb/ml/models/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,7 @@ def save(self, *, update: bool = False, meta: Optional[Meta] = None) -> None:
if not update:
self._create_array(fields=["model"])

self._write_array(model_params={"model": serialized_model})

if meta:
self._write_model_metadata(meta=meta)
self._write_array(model_params={"model": serialized_model}, meta=meta)

def load(self, *, timestamp: Optional[Timestamp] = None) -> BaseEstimator:
"""
Expand All @@ -62,42 +59,17 @@ def load(self, *, timestamp: Optional[Timestamp] = None) -> BaseEstimator:
in the specified time range.
:return: A Sklearn model object.
"""
# TODO: Change timestamp when issue in core is resolved

load = (
self.__load_legacy
if self._use_legacy_schema(timestamp=timestamp)
else self.__load
)
return load(timestamp=timestamp)

def __load_legacy(self, *, timestamp: Optional[Timestamp]) -> BaseEstimator:
"""
Load a Sklearn model from a TileDB array.
"""
model_array = tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp)
model_array_results = model_array[:]
model = pickle.loads(model_array_results["model_params"].item(0))
return model
with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
if self._use_legacy_schema(model_array):
return self.__load_legacy(model_array)
else:
return self.__load(model_array)

def __load(self, *, timestamp: Optional[Timestamp]) -> BaseEstimator:
"""
Load a Sklearn model from a TileDB array.
"""
def __load_legacy(self, model_array: tiledb.Array) -> BaseEstimator:
    """Deserialize a Sklearn model stored with the legacy array schema."""
    raw_bytes = model_array[:]["model_params"].item(0)
    return pickle.loads(raw_bytes)

with tiledb.open(self.uri, ctx=self.ctx, timestamp=timestamp) as model_array:
try:
model_size = model_array.meta["model_size"]
except KeyError:
raise Exception(
f"model_size metadata entry not present in {self.uri}"
f" (existing keys: {set(model_array.meta.keys())})"
)

model_contents = model_array[0:model_size]["model"]
model_bytes = model_contents.tobytes()

return pickle.loads(model_bytes)
def __load(self, model_array: tiledb.Array) -> BaseEstimator:
    """Deserialize a Sklearn model stored with the current array schema."""
    estimator = self._get_model_param(model_array, "model")
    return estimator

def preview(self, *, display: str = "text") -> str:
"""
Expand Down
Loading

0 comments on commit 92f594d

Please sign in to comment.