-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from dcos-labs/example_code
Moved example code into repo.
- Loading branch information
Showing
20 changed files
with
3,518 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# TensorFlow Example Scripts | ||
|
||
These examples contains TensorFlow scripts for testing/demo purposes. These models are not intended for any other use cases. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""Alexnet model configuration. | ||
References: | ||
Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton | ||
ImageNet Classification with Deep Convolutional Neural Networks | ||
Advances in Neural Information Processing Systems. 2012 | ||
""" | ||
|
||
import model | ||
|
||
|
||
class AlexnetModel(model.Model): | ||
"""Alexnet cnn model.""" | ||
|
||
def __init__(self): | ||
super(AlexnetModel, self).__init__('alexnet', 224 + 3, 512, 0.005) | ||
|
||
def add_inference(self, cnn): | ||
# Note: VALID requires padding the images by 3 in width and height | ||
cnn.conv(64, 11, 11, 4, 4, 'VALID') | ||
cnn.mpool(3, 3, 2, 2) | ||
cnn.conv(192, 5, 5) | ||
cnn.mpool(3, 3, 2, 2) | ||
cnn.conv(384, 3, 3) | ||
cnn.conv(384, 3, 3) | ||
cnn.conv(256, 3, 3) | ||
cnn.mpool(3, 3, 2, 2) | ||
cnn.reshape([-1, 256 * 6 * 6]) | ||
cnn.affine(4096) | ||
cnn.dropout() | ||
cnn.affine(4096) | ||
cnn.dropout() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Provides ways to store benchmark output.""" | ||
|
||
|
||
def store_benchmark(data, storage_type=None): | ||
"""Store benchmark data. | ||
Args: | ||
data: Dictionary mapping from string benchmark name to | ||
numeric benchmark value. | ||
storage_type: (string) Specifies where to store benchmark | ||
result. If storage_type is | ||
'cbuild_benchmark_datastore': store outputs in our continuous | ||
build datastore. gcloud must be setup in current environment | ||
pointing to the project where data will be added. | ||
""" | ||
if storage_type == 'cbuild_benchmark_datastore': | ||
try: | ||
# pylint: disable=g-import-not-at-top | ||
import cbuild_benchmark_storage | ||
# pylint: enable=g-import-not-at-top | ||
except ImportError: | ||
raise ImportError( | ||
'Missing cbuild_benchmark_storage.py required for ' | ||
'benchmark_cloud_datastore option') | ||
cbuild_benchmark_storage.upload_to_benchmark_datastore(data) | ||
else: | ||
assert False, 'unknown storage_type: ' + storage_type |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Provides a way to store benchmark results in GCE Datastore. | ||
Datastore client is initialized from current environment. | ||
Data is stored using the format defined in: | ||
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/test/upload_test_benchmarks_index.yaml | ||
""" | ||
from datetime import datetime | ||
import json | ||
import os | ||
import sys | ||
from google.cloud import datastore | ||
|
||
|
||
_TEST_NAME_ENV_VAR = 'TF_DIST_BENCHMARK_NAME' | ||
|
||
|
||
def upload_to_benchmark_datastore(data, test_name=None, start_time=None): | ||
"""Use a new datastore.Client to upload data to datastore. | ||
Create the datastore Entities from that data and upload them to the | ||
datastore in a batch using the client connection. | ||
Args: | ||
data: Map from benchmark names to values. | ||
test_name: Name of this test. If not specified, name will be set either | ||
from TF_DIST_BENCHMARK_NAME environment variable or to default name | ||
'TestBenchmark'. | ||
start_time: (datetime) Time to record for this test. | ||
Raises: | ||
ValueError: if test_name is not passed in and TF_DIST_BENCHMARK_NAME | ||
is not set. | ||
""" | ||
client = datastore.Client() | ||
|
||
if not test_name: | ||
if _TEST_NAME_ENV_VAR in os.environ: | ||
test_name = os.environ[_TEST_NAME_ENV_VAR] | ||
else: | ||
raise ValueError( | ||
'No test name passed in for benchmarks. ' | ||
'Either pass a test_name to upload_to_benchmark_datastore or ' | ||
'set %s environment variable.' % _TEST_NAME_ENV_VAR) | ||
test_name = unicode(test_name) | ||
|
||
if not start_time: | ||
start_time = datetime.now() | ||
|
||
# Create one Entry Entity for each benchmark entry. The wall-clock timing is | ||
# the attribute to be fetched and displayed. The full entry information is | ||
# also stored as a non-indexed JSON blob. | ||
entries = [] | ||
batch = [] | ||
for name, value in data.items(): | ||
e_key = client.key('Entry') | ||
e_val = datastore.Entity(e_key, exclude_from_indexes=['info']) | ||
entry_map = {'name': name, 'wallTime': value, 'iters': '1'} | ||
entries.append(entry_map) | ||
e_val.update({ | ||
'test': test_name, | ||
'start': start_time, | ||
'entry': unicode(name), | ||
'timing': value, | ||
'info': unicode(json.dumps(entry_map)) | ||
}) | ||
batch.append(e_val) | ||
|
||
# Create the Test Entity containing all the test information as a | ||
# non-indexed JSON blob. | ||
test_result = json.dumps( | ||
{'name': test_name, | ||
'startTime': (start_time - datetime(1970, 1, 1)).total_seconds(), | ||
'entries': {'entry': entries}, | ||
'runConfiguration': {'argument': sys.argv[1:]}}) | ||
t_key = client.key('Test') | ||
t_val = datastore.Entity(t_key, exclude_from_indexes=['info']) | ||
t_val.update({ | ||
'test': test_name, | ||
'start': start_time, | ||
'info': unicode(test_result) | ||
}) | ||
batch.append(t_val) | ||
|
||
# Put the whole batch of Entities in the datastore. | ||
client.put_multi(batch) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""Utilities for CNN benchmarks.""" | ||
|
||
import tensorflow as tf | ||
|
||
|
||
def tensorflow_version_tuple(): | ||
v = tf.__version__ | ||
major, minor, patch = v.split('.') | ||
return (int(major), int(minor), patch) | ||
|
||
|
||
def tensorflow_version(): | ||
vt = tensorflow_version_tuple() | ||
return vt[0] * 1000 + vt[1] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""Benchmark dataset utilities. | ||
""" | ||
|
||
from abc import abstractmethod | ||
import os | ||
|
||
import tensorflow as tf | ||
|
||
|
||
class Dataset(object): | ||
"""Abstract class for cnn benchmarks dataset.""" | ||
|
||
def __init__(self, name, data_dir=None): | ||
self.name = name | ||
if data_dir is None: | ||
raise ValueError('Data directory not specified') | ||
self.data_dir = data_dir | ||
|
||
def tf_record_pattern(self, subset): | ||
return os.path.join(self.data_dir, '%s-*-of-*' % subset) | ||
|
||
def reader(self): | ||
return tf.TFRecordReader() | ||
|
||
@abstractmethod | ||
def num_classes(self): | ||
pass | ||
|
||
@abstractmethod | ||
def num_examples_per_epoch(self, subset): | ||
pass | ||
|
||
def __str__(self): | ||
return self.name | ||
|
||
|
||
class FlowersData(Dataset): | ||
|
||
def __init__(self, data_dir=None): | ||
super(FlowersData, self).__init__('Flowers', data_dir) | ||
|
||
def num_classes(self): | ||
return 5 | ||
|
||
def num_examples_per_epoch(self, subset): | ||
if subset == 'train': | ||
return 3170 | ||
elif subset == 'validation': | ||
return 500 | ||
else: | ||
raise ValueError('Invalid data subset "%s"' % subset) | ||
|
||
|
||
class ImagenetData(Dataset): | ||
|
||
def __init__(self, data_dir=None): | ||
super(ImagenetData, self).__init__('ImageNet', data_dir) | ||
|
||
def num_classes(self): | ||
return 1000 | ||
|
||
def num_examples_per_epoch(self, subset='train'): | ||
if subset == 'train': | ||
return 1281167 | ||
elif subset == 'validation': | ||
return 50000 | ||
else: | ||
raise ValueError('Invalid data subset "%s"' % subset) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
"""Googlenet model configuration. | ||
References: | ||
Szegedy, Christian, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, | ||
Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich | ||
Going deeper with convolutions | ||
arXiv preprint arXiv:1409.4842 (2014) | ||
""" | ||
|
||
import model | ||
|
||
|
||
class GooglenetModel(model.Model): | ||
|
||
def __init__(self): | ||
super(GooglenetModel, self).__init__('googlenet', 224, 32, 0.005) | ||
|
||
def add_inference(self, cnn): | ||
def inception_v1(cnn, k, l, m, n, p, q): | ||
cols = [[('conv', k, 1, 1)], [('conv', l, 1, 1), ('conv', m, 3, 3)], | ||
[('conv', n, 1, 1), ('conv', p, 5, 5)], | ||
[('mpool', 3, 3, 1, 1, 'SAME'), ('conv', q, 1, 1)]] | ||
cnn.inception_module('incept_v1', cols) | ||
|
||
cnn.conv(64, 7, 7, 2, 2) | ||
cnn.mpool(3, 3, 2, 2, mode='SAME') | ||
cnn.conv(64, 1, 1) | ||
cnn.conv(192, 3, 3) | ||
cnn.mpool(3, 3, 2, 2, mode='SAME') | ||
inception_v1(cnn, 64, 96, 128, 16, 32, 32) | ||
inception_v1(cnn, 128, 128, 192, 32, 96, 64) | ||
cnn.mpool(3, 3, 2, 2, mode='SAME') | ||
inception_v1(cnn, 192, 96, 208, 16, 48, 64) | ||
inception_v1(cnn, 160, 112, 224, 24, 64, 64) | ||
inception_v1(cnn, 128, 128, 256, 24, 64, 64) | ||
inception_v1(cnn, 112, 144, 288, 32, 64, 64) | ||
inception_v1(cnn, 256, 160, 320, 32, 128, 128) | ||
cnn.mpool(3, 3, 2, 2, mode='SAME') | ||
inception_v1(cnn, 256, 160, 320, 32, 128, 128) | ||
inception_v1(cnn, 384, 192, 384, 48, 128, 128) | ||
cnn.apool(7, 7, 1, 1, mode='VALID') | ||
cnn.reshape([-1, 1024]) |
Oops, something went wrong.