# --- threaded_estimator/models.py ---


class FlowerClassifierGenerator(FlowerClassifier):
    """
    Flower classifier that serves every request from one prediction generator.

    ``self.estimator.predict`` is called a single time, in ``__init__``, with
    an input function whose dataset is fed by ``generator``. Each call to
    ``predict`` stores the new features on the instance and then pulls results
    from that one long-lived prediction generator.
    """

    def __init__(self, model_path='./trained_models/',
                 verbose=False):
        """
        Parameters
        ----------
        model_path: str
            Forwarded unchanged to the parent constructor.
        verbose: bool
            Forwarded unchanged to the parent constructor.
        """
        super().__init__(model_path=model_path, verbose=verbose)
        # Features handed to the dataset by `generator`; replaced on every
        # `predict` call.
        self.next_features = None
        # Built once here and reused by every `predict` call.
        self.prediction = self.estimator.predict(
            input_fn=self.generator_predict_input_fn)

    def generator(self):
        """
        Endlessly yield the features most recently stored by ``predict``.

        Yields
        ------
        features: dict
            dict of input features, containing keys 'SepalLength'
                                                    'SepalWidth'
                                                    'PetalLength'
                                                    'PetalWidth'
            (``None`` until ``predict`` has stored a value).
        """
        while True:
            yield self.next_features

    def predict(self, features):
        """
        Overrides the parent class's ``predict``.

        This is a generator function: nothing happens, including the update
        of ``self.next_features``, until the returned generator is iterated.

        Parameters
        ----------
        features: dict
            dict of input features, containing keys 'SepalLength'
                                                    'SepalWidth'
                                                    'PetalLength'
                                                    'PetalWidth'
            An empty dict yields no predictions.

        Yields
        ------
        predictions: dict
            Dictionary containing   'probs'
                                    'outputs'
                                    'predicted_class'

        """
        keys = list(features.keys())
        if not keys:
            # No feature columns means no examples: yield nothing instead of
            # failing with an IndexError on keys[0].
            return

        self.next_features = features
        # One prediction is pulled per entry of the first feature column.
        # NOTE(review): this assumes every earlier call was consumed to the
        # end; if a caller stops early, results for the previous features may
        # still be pending in self.prediction -- confirm against the
        # Estimator.predict semantics.
        for _ in range(len(features[keys[0]])):
            yield next(self.prediction)

    def generator_predict_input_fn(self):
        """
        Construct a tf.data.Dataset from a generator

        Returns
        -------
        dataset: tf.data.Dataset
            Dataset whose elements are produced by ``self.generator``.
        """
        return tf.data.Dataset.from_generator(
            self.generator,
            output_types={'SepalLength': tf.float32,
                          'SepalWidth': tf.float32,
                          'PetalLength': tf.float32,
                          'PetalWidth': tf.float32})


# --- threaded_estimator/tests/test_flower_estimator.py ---


def test_generator_faster_than_threaded():
    """Time both classifiers and check that the generator variant is faster."""
    fe_threaded = models.FlowerClassifierThreaded(threaded=True)
    fe_generator = models.FlowerClassifierGenerator()

    n_epochs = 1000

    # perf_counter is monotonic, so the measured intervals cannot be skewed
    # by system clock adjustments.
    print('starting threaded')
    t1 = time.perf_counter()
    for _ in range(n_epochs):
        # list() exhausts the result so the full prediction cost is timed.
        list(fe_threaded.predict(features=predict_x))

    print('starting generator')
    t2 = time.perf_counter()
    for _ in range(n_epochs):
        list(fe_generator.predict(features=predict_x))

    t3 = time.perf_counter()

    threaded_time = (t2-t1)
    generator_time = (t3-t2)

    # Report before asserting so the timings are visible when the test fails.
    print(f'Threaded time was {threaded_time}; s\n'
          f'Generator time was {generator_time}; s\n'
          f'Generator was {threaded_time/generator_time} times faster!')

    assert generator_time < threaded_time