Skip to content
This repository has been archived by the owner on Aug 30, 2023. It is now read-only.

Commit

Permalink
Merge pull request #27 from mlgill/memory_mapped_data_loading
Browse files Browse the repository at this point in the history
Memory mapped data loading
  • Loading branch information
gessulat authored Nov 19, 2019
2 parents b7503d8 + 93a4a07 commit ce93f3a
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 8 deletions.
2 changes: 1 addition & 1 deletion prosit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import io
from . import io_local
from . import constants
from . import model
from . import alignment
Expand Down
29 changes: 29 additions & 0 deletions prosit/io_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from . import utils


def get_array(tensor, keys):
    """Return the entries of ``tensor`` selected by ``keys``, in key order.

    ``utils.check_mandatory_keys(tensor, keys)`` is called first; it
    presumably rejects a ``tensor`` that lacks one of ``keys`` -- confirm
    against ``utils``.

    Args:
        tensor: Object indexable by each element of ``keys``.
        keys: Iterable of keys to look up.

    Returns:
        A list holding ``tensor[key]`` for every key, in iteration order.
    """
    utils.check_mandatory_keys(tensor, keys)
    selected = []
    for name in keys:
        selected.append(tensor[name])
    return selected


def to_hdf5(dictionary, path):
    """Store every item of ``dictionary`` as a dataset in an HDF5 file.

    The file at ``path`` is opened with mode ``"w"`` and one gzip-compressed
    dataset is created per dictionary item, named after the item's key.

    Args:
        dictionary: Mapping of dataset name to data. Each value must expose
            a ``dtype`` attribute, which is passed on to ``create_dataset``.
        path: Location of the HDF5 file to write.
    """
    # Imported here rather than at module level, as in the original code.
    import h5py

    with h5py.File(path, "w") as h5_out:
        for name, values in dictionary.items():
            h5_out.create_dataset(
                name,
                data=values,
                dtype=values.dtype,
                compression="gzip",
            )


def from_hdf5(path, n_samples=None):
    """Build a dictionary of ``HDF5Matrix`` objects, one per top-level key.

    Args:
        path: Location of the HDF5 file to read.
        n_samples: Passed to each ``HDF5Matrix`` as ``end`` (with
            ``start=0``). Defaults to ``None``.

    Returns:
        A dict mapping every top-level key of the file to an ``HDF5Matrix``
        created with ``normalizer=None``.
    """
    from keras.utils import HDF5Matrix
    import h5py

    # Only the key names are collected here; this handle is closed again
    # before any HDF5Matrix is constructed.
    with h5py.File(path, "r") as h5_in:
        names = list(h5_in.keys())

    return {
        name: HDF5Matrix(path, name, start=0, end=n_samples, normalizer=None)
        for name in names
    }
4 changes: 2 additions & 2 deletions prosit/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import numpy as np

from . import model as model_lib
from . import io
from . import io_local
from . import constants
from . import sanitize


def predict(data, d_model):
# check for mandatory keys
x = io.get_array(data, d_model["config"]["x"])
x = io_local.get_array(data, d_model["config"]["x"])

keras.backend.set_session(d_model["session"])
with d_model["graph"].as_default():
Expand Down
2 changes: 1 addition & 1 deletion prosit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import tensorflow as tf

from . import model
from . import io
from . import io_local
from . import constants
from . import tensorize
from . import prediction
Expand Down
8 changes: 4 additions & 4 deletions prosit/training.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from . import io
from . import io_local
from . import losses
from . import model as model_lib
from . import constants
Expand Down Expand Up @@ -28,8 +28,8 @@ def train(tensor, model, model_config, callbacks):
else:
loss = losses.get(model_config["loss"])
optimizer = model_config["optimizer"]
x = io.get_array(tensor, model_config["x"])
y = io.get_array(tensor, model_config["y"])
x = io_local.get_array(tensor, model_config["x"])
y = io_local.get_array(tensor, model_config["y"])
model.compile(optimizer=optimizer, loss=loss)
model.fit(
x=x,
Expand All @@ -48,6 +48,6 @@ def train(tensor, model, model_config, callbacks):
model_dir = constants.MODEL_DIR

model, model_config = model_lib.load(model_dir, trained=True)
tensor = io.from_hdf5(data_path)
tensor = io_local.from_hdf5(data_path)
callbacks = get_callbacks(model_dir)
train(tensor, model, model_config, callbacks)

0 comments on commit ce93f3a

Please sign in to comment.