Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 71 additions & 4 deletions threaded_estimator/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
""" Module to expose trained models for inference."""

import tensorflow as tf

from tensorflow.contrib.learn import RunConfig
from queue import Queue
from threading import Thread

import tensorflow as tf
from tensorflow.contrib.learn import RunConfig

from threaded_estimator import iris_data


Expand Down Expand Up @@ -35,9 +35,10 @@ def __init__(self, model_path='./trained_models/',

self.model_path = model_path

self.estimator = self.load_estimator()
self.estimator: tf.estimator.Estimator = self.load_estimator()

self.verbose = verbose
self.saved_model = None

def predict(self, features):
"""
Expand Down Expand Up @@ -284,3 +285,69 @@ def queued_predict_input_fn(self):
'PetalWidth': tf.float32})

return dataset


class FlowerClassifierGenerator(FlowerClassifier):

def __init__(self, model_path='./trained_models/',
verbose=False):
super(FlowerClassifierGenerator, self).__init__(model_path=model_path,
verbose=verbose)
self.next_features = None
self.prediction = self.estimator.predict(input_fn=self.generator_predict_input_fn)

def generator(self):
"""

Yield
-------
features: dict
dict of input features, containing keys 'SepalLength'
'SepalWidth'
'PetalLength'
'PetalWidth'
"""
while True:
yield self.next_features

def predict(self, features):
"""
Overwrites .predict in FlowerClassifierBasic.


Parameters
----------
features: dict
dict of input features, containing keys 'SepalLength'
'SepalWidth'
'PetalLength'
'PetalWidth'

Yield
-------
predictions: dict
Dictionary containing 'probs'
'outputs'
'predicted_class'

"""
self.next_features = features
keys = list(features.keys())
for _ in range(len(features[keys[0]])):
yield next(self.prediction)

def generator_predict_input_fn(self):
"""
Construct a tf.data.Dataset from a generator

Return
-------
dataset: tf.data.Dataset
"""
dataset = tf.data.Dataset.from_generator(self.generator,
output_types={'SepalLength': tf.float32,
'SepalWidth': tf.float32,
'PetalLength': tf.float32,
'PetalWidth': tf.float32})

return dataset
28 changes: 28 additions & 0 deletions threaded_estimator/tests/test_flower_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,33 @@ def test_threaded_faster_than_non_threaded():
f'Threaded was {unthreaded_time/threaded_time} times faster!')


def test_generator_faster_than_threaded():

fe_threaded = models.FlowerClassifierThreaded(threaded=True)
fe_generator = models.FlowerClassifierGenerator()

n_epochs = 1000

print('starting threaded')
t1 = time.time()
for _ in range(n_epochs):
predictions = list(fe_threaded.predict(features=predict_x))

print('starting generator')
t2 = time.time()
for _ in range(n_epochs):
predictions = list(fe_generator.predict(features=predict_x))

t3 = time.time()

threaded_time = (t2-t1)
generator_time = (t3-t2)

assert generator_time < threaded_time

print(f'Threaded time was {threaded_time}; s\n'
f'Generator time was {generator_time}; s\n'
f'Generator was {threaded_time/generator_time} times faster!')