-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from xldrx/monkey-patching
Add Monkey Patching
- Loading branch information
Showing
10 changed files
with
183 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#! /usr/bin/env python -u | ||
# coding=utf-8 | ||
|
||
# Using tracing server with TesorFlow Estimator API | ||
|
||
__author__ = 'Sayed Hadi Hashemi' | ||
|
||
import tensorflow as tf | ||
|
||
import tftracer | ||
tftracer.hook_inject() | ||
|
||
import numpy as np | ||
|
||
INPUT_SIZE = (299, 299, 3) | ||
MINIBATCH_SIZE = 128 | ||
NUM_CLASSES = 1000 | ||
NUM_STEPS = 500 | ||
|
||
|
||
def input_fn(): | ||
dataset = tf.data.Dataset.from_tensor_slices([0]).repeat(MINIBATCH_SIZE) | ||
dataset = dataset.map( | ||
lambda _: | ||
( | ||
{"x": np.random.uniform(size=INPUT_SIZE)}, | ||
[np.random.random_integers(0, NUM_CLASSES)] | ||
) | ||
) | ||
dataset = dataset.repeat(NUM_STEPS).batch(MINIBATCH_SIZE) | ||
return dataset | ||
|
||
|
||
def main(): | ||
estimator = tf.estimator.DNNClassifier( | ||
hidden_units=[10] * 150, | ||
feature_columns=[tf.feature_column.numeric_column("x", shape=INPUT_SIZE)], | ||
n_classes=NUM_CLASSES, | ||
) | ||
estimator.train(input_fn) | ||
estimator.evaluate(input_fn) | ||
|
||
|
||
if __name__ == '__main__': | ||
tf.logging.set_verbosity(tf.logging.INFO) | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#! /usr/bin/env python -u | ||
# coding=utf-8 | ||
|
||
__author__ = 'Sayed Hadi Hashemi' | ||
|
||
|
||
def __add_tracing_server_hook(hooks): | ||
if hooks is None: | ||
return [hook_inject.__tracing_server.hook] | ||
else: | ||
hooks = list(hooks) | ||
hooks.append(hook_inject.__tracing_server.hook) | ||
return hooks | ||
|
||
|
||
def __new_init(*args, **kwargs): | ||
if hook_inject.__original_init is None: | ||
return | ||
|
||
if "hooks" in hook_inject.__original_init.__code__.co_varnames: | ||
hooks_index = hook_inject.__original_init.__code__.co_varnames.index("hooks") | ||
if len(args) > hooks_index: | ||
args = list(args) | ||
args[hooks_index] = __add_tracing_server_hook(args[hooks_index]) | ||
else: | ||
kwargs["hooks"] = __add_tracing_server_hook(kwargs.get("hooks", None)) | ||
else: | ||
print("'hooks' not in '_MonitoredSession'") | ||
|
||
hook_inject.__original_init(*args, **kwargs) | ||
|
||
|
||
def hook_inject(*args, **kwargs): | ||
""" | ||
(Experimental) Injects a tracing server hook to all instances of ``MonitoredSession`` by by monkey patching | ||
the initializer. This function is an alternative to adding `hooks` to estimator or sessions. | ||
Be aware, monkey patching could cause unexpected errors and is not recommended. | ||
This function should be called once in the main script preferably before importing anything else. | ||
Example: | ||
.. code-block:: python | ||
import tftracer | ||
tftracer.hook_inject() | ||
... | ||
estimator.train(input_fn) | ||
Args: | ||
**kwargs: same as :class:`tftracer.TracingServer`. | ||
Note: | ||
Monkey Patching (as :class:`tftracer.TracingServer`) works only with subclasses of ``MonitoredSession``. | ||
For other ``Session`` types, use :class:`tftracer.Timeline`. | ||
""" | ||
from . import TracingServer | ||
from tensorflow.python.training.monitored_session import _MonitoredSession | ||
|
||
if hook_inject.__original_init is None: | ||
hook_inject.__original_init = _MonitoredSession.__init__ | ||
hook_inject.__tracing_server = TracingServer(*args, **kwargs) | ||
_MonitoredSession.__init__ = __new_init | ||
|
||
|
||
hook_inject.__tracing_server = None | ||
hook_inject.__original_init = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#! /usr/bin/env python -u | ||
# coding=utf-8 | ||
|
||
__author__ = 'Sayed Hadi Hashemi' | ||
|
||
__version__ = '1.1.0' |