Skip to content

Commit

Permalink
Merge pull request #9 from dcos-labs/example_code
Browse files Browse the repository at this point in the history
Moved example code into repo.
  • Loading branch information
Joerg Schad authored Oct 29, 2017
2 parents b0ec496 + 380d80b commit f3dcb5d
Show file tree
Hide file tree
Showing 20 changed files with 3,518 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tf_examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# TensorFlow Example Scripts

These examples contain TensorFlow scripts for testing and demo purposes. The models are not intended for any other use case.
Empty file.
47 changes: 47 additions & 0 deletions tf_examples/benchmarks/alexnet_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Alexnet model configuration.
References:
Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton
ImageNet Classification with Deep Convolutional Neural Networks
Advances in Neural Information Processing Systems. 2012
"""

import model


class AlexnetModel(model.Model):
  """AlexNet convolutional network definition for the benchmark suite."""

  def __init__(self):
    # The first convolution uses 'VALID' padding, so the nominal 224-pixel
    # input is enlarged by 3 pixels in width and height.
    padded_image_size = 224 + 3
    super(AlexnetModel, self).__init__(
        'alexnet', padded_image_size, 512, 0.005)

  def add_inference(self, cnn):
    """Append the AlexNet layers to `cnn`.

    Args:
      cnn: Network builder object; only its conv, mpool, reshape, affine and
        dropout methods are used here.
    """
    # Note: VALID requires padding the images by 3 in width and height
    cnn.conv(64, 11, 11, 4, 4, 'VALID')
    cnn.mpool(3, 3, 2, 2)
    cnn.conv(192, 5, 5)
    cnn.mpool(3, 3, 2, 2)
    # Three stacked 3x3 convolutions.
    for depth in (384, 384, 256):
      cnn.conv(depth, 3, 3)
    cnn.mpool(3, 3, 2, 2)
    flattened_size = 256 * 6 * 6
    cnn.reshape([-1, flattened_size])
    # Two identical affine + dropout stages.
    for _ in range(2):
      cnn.affine(4096)
      cnn.dropout()
41 changes: 41 additions & 0 deletions tf_examples/benchmarks/benchmark_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides ways to store benchmark output."""


def store_benchmark(data, storage_type=None):
  """Store benchmark data.

  Args:
    data: Dictionary mapping from string benchmark name to
      numeric benchmark value.
    storage_type: (string) Specifies where to store benchmark
      result. If storage_type is
      'cbuild_benchmark_datastore': store outputs in our continuous
        build datastore. gcloud must be setup in current environment
        pointing to the project where data will be added.

  Raises:
    ValueError: if storage_type is not a recognized value (including the
      default None). This replaces an `assert False` that was stripped
      under `python -O` and that failed with TypeError for non-string
      values such as None.
    ImportError: if the cbuild_benchmark_storage module cannot be imported.
  """
  if storage_type != 'cbuild_benchmark_datastore':
    # %-format with a 1-tuple so None (or any non-string) is rendered
    # instead of breaking string concatenation.
    raise ValueError('unknown storage_type: %s' % (storage_type,))

  try:
    # Imported lazily so the google-cloud dependency is only required when
    # this storage backend is actually selected.
    # pylint: disable=g-import-not-at-top
    import cbuild_benchmark_storage
    # pylint: enable=g-import-not-at-top
  except ImportError:
    raise ImportError(
        'Missing cbuild_benchmark_storage.py required for '
        'cbuild_benchmark_datastore option')
  cbuild_benchmark_storage.upload_to_benchmark_datastore(data)
99 changes: 99 additions & 0 deletions tf_examples/benchmarks/cbuild_benchmark_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides a way to store benchmark results in GCE Datastore.
Datastore client is initialized from current environment.
Data is stored using the format defined in:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/test/upload_test_benchmarks_index.yaml
"""
from datetime import datetime
import json
import os
import sys
from google.cloud import datastore


# Name of the environment variable consulted for the test name when the caller
# of upload_to_benchmark_datastore does not pass test_name explicitly.
_TEST_NAME_ENV_VAR = 'TF_DIST_BENCHMARK_NAME'


def upload_to_benchmark_datastore(data, test_name=None, start_time=None):
  """Use a new datastore.Client to upload data to datastore.

  Create the datastore Entities from that data and upload them to the
  datastore in a batch using the client connection. One 'Entry' entity is
  created per item in `data`, plus a single 'Test' entity summarizing the
  whole run.

  Args:
    data: Map from benchmark names to values.
    test_name: Name of this test. If not specified, the name is read from
      the TF_DIST_BENCHMARK_NAME environment variable.
    start_time: (datetime) Time to record for this test. Defaults to
      datetime.now() when not given.

  Raises:
    ValueError: if test_name is not passed in and TF_DIST_BENCHMARK_NAME
      is not set.
  """
  # The `unicode` builtin only exists on Python 2. type(u'') resolves to the
  # native text type on both Python 2 (unicode) and Python 3 (str).
  text_type = type(u'')

  client = datastore.Client()

  if not test_name:
    if _TEST_NAME_ENV_VAR in os.environ:
      test_name = os.environ[_TEST_NAME_ENV_VAR]
    else:
      raise ValueError(
          'No test name passed in for benchmarks. '
          'Either pass a test_name to upload_to_benchmark_datastore or '
          'set %s environment variable.' % _TEST_NAME_ENV_VAR)
  test_name = text_type(test_name)

  if not start_time:
    # NOTE(review): datetime.now() is naive local time, while 'startTime'
    # below is computed relative to 1970-01-01 as if it were UTC. Confirm
    # whether utcnow() was intended before changing recorded values.
    start_time = datetime.now()

  # Create one Entry Entity for each benchmark entry.  The wall-clock timing is
  # the attribute to be fetched and displayed.  The full entry information is
  # also stored as a non-indexed JSON blob.
  entries = []
  batch = []
  for name, value in data.items():
    e_key = client.key('Entry')
    e_val = datastore.Entity(e_key, exclude_from_indexes=['info'])
    entry_map = {'name': name, 'wallTime': value, 'iters': '1'}
    entries.append(entry_map)
    e_val.update({
        'test': test_name,
        'start': start_time,
        'entry': text_type(name),
        'timing': value,
        'info': text_type(json.dumps(entry_map))
    })
    batch.append(e_val)

  # Create the Test Entity containing all the test information as a
  # non-indexed JSON blob.
  test_result = json.dumps(
      {'name': test_name,
       'startTime': (start_time - datetime(1970, 1, 1)).total_seconds(),
       'entries': {'entry': entries},
       'runConfiguration': {'argument': sys.argv[1:]}})
  t_key = client.key('Test')
  t_val = datastore.Entity(t_key, exclude_from_indexes=['info'])
  t_val.update({
      'test': test_name,
      'start': start_time,
      'info': text_type(test_result)
  })
  batch.append(t_val)

  # Put the whole batch of Entities in the datastore.
  client.put_multi(batch)
30 changes: 30 additions & 0 deletions tf_examples/benchmarks/cnn_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for CNN benchmarks."""

import tensorflow as tf


def tensorflow_version_tuple():
  """Return the TensorFlow version as a (major, minor, patch) tuple.

  Returns:
    A 3-tuple where major and minor are ints and patch is the remaining
    version text as a string (it is not converted, so suffixes such as
    '0-rc1' are preserved).
  """
  # maxsplit=2 keeps any additional dot-separated suffix inside `patch`
  # rather than failing the three-way unpack with a ValueError.
  major, minor, patch = tf.__version__.split('.', 2)
  return (int(major), int(minor), patch)


def tensorflow_version():
  """Return the TensorFlow version as a single int: major * 1000 + minor."""
  version_parts = tensorflow_version_tuple()
  major = version_parts[0]
  minor = version_parts[1]
  return major * 1000 + minor

83 changes: 83 additions & 0 deletions tf_examples/benchmarks/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Benchmark dataset utilities.
"""

from abc import abstractmethod
import os

import tensorflow as tf


class Dataset(object):
  """Base description of a dataset used by the CNN benchmarks.

  Subclasses are expected to override num_classes and
  num_examples_per_epoch. Note that the class does not declare an ABCMeta
  metaclass, so the abstractmethod markers are not enforced at
  instantiation time.
  """

  def __init__(self, name, data_dir=None):
    """Record the dataset name and the directory holding its files.

    Args:
      name: Human-readable dataset name; also returned by str().
      data_dir: Directory containing the dataset files. Required.

    Raises:
      ValueError: if data_dir is None.
    """
    self.name = name
    if data_dir is not None:
      self.data_dir = data_dir
    else:
      raise ValueError('Data directory not specified')

  def tf_record_pattern(self, subset):
    """Return the file glob under data_dir for the given subset."""
    file_glob = '%s-*-of-*' % subset
    return os.path.join(self.data_dir, file_glob)

  def reader(self):
    """Return a new tf.TFRecordReader."""
    record_reader = tf.TFRecordReader()
    return record_reader

  @abstractmethod
  def num_classes(self):
    """Return the number of classes; overridden by subclasses."""
    return None

  @abstractmethod
  def num_examples_per_epoch(self, subset):
    """Return the example count for `subset`; overridden by subclasses."""
    return None

  def __str__(self):
    return self.name


class FlowersData(Dataset):
  """Description of the Flowers dataset."""

  def __init__(self, data_dir=None):
    """Create the dataset description.

    Args:
      data_dir: Directory containing the dataset files; forwarded to
        Dataset.__init__.
    """
    super(FlowersData, self).__init__('Flowers', data_dir)

  def num_classes(self):
    """Return the number of classes in the dataset."""
    return 5

  def num_examples_per_epoch(self, subset='train'):
    """Return the number of examples in the given subset.

    Args:
      subset: Either 'train' or 'validation'. Defaults to 'train', matching
        the signature of ImagenetData.num_examples_per_epoch.

    Raises:
      ValueError: if subset is neither 'train' nor 'validation'.
    """
    if subset == 'train':
      return 3170
    elif subset == 'validation':
      return 500
    else:
      raise ValueError('Invalid data subset "%s"' % subset)


class ImagenetData(Dataset):
  """Description of the ImageNet dataset."""

  def __init__(self, data_dir=None):
    """Create the dataset description; data_dir is forwarded to Dataset."""
    super(ImagenetData, self).__init__('ImageNet', data_dir)

  def num_classes(self):
    """Return the number of classes in the dataset."""
    return 1000

  def num_examples_per_epoch(self, subset='train'):
    """Return the number of examples in the 'train' or 'validation' subset.

    Raises:
      ValueError: if subset is neither 'train' nor 'validation'.
    """
    if subset == 'train':
      return 1281167
    if subset == 'validation':
      return 50000
    raise ValueError('Invalid data subset "%s"' % subset)
57 changes: 57 additions & 0 deletions tf_examples/benchmarks/googlenet_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Googlenet model configuration.
References:
Szegedy, Christian, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich
Going deeper with convolutions
arXiv preprint arXiv:1409.4842 (2014)
"""

import model


class GooglenetModel(model.Model):
  """GoogLeNet model definition for the benchmark suite."""

  def __init__(self):
    super(GooglenetModel, self).__init__('googlenet', 224, 32, 0.005)

  def add_inference(self, cnn):
    """Append the GoogLeNet layers to `cnn`.

    Args:
      cnn: Network builder object; only its conv, mpool, apool, reshape and
        inception_module methods are used here.
    """

    def add_inception(conv1, reduce3, conv3, reduce5, conv5, pool_proj):
      # Four parallel columns: 1x1 conv, 1x1 -> 3x3 conv, 1x1 -> 5x5 conv,
      # and 3x3 max-pool -> 1x1 conv.
      columns = [
          [('conv', conv1, 1, 1)],
          [('conv', reduce3, 1, 1), ('conv', conv3, 3, 3)],
          [('conv', reduce5, 1, 1), ('conv', conv5, 5, 5)],
          [('mpool', 3, 3, 1, 1, 'SAME'), ('conv', pool_proj, 1, 1)],
      ]
      cnn.inception_module('incept_v1', columns)

    # Filter counts for each group of inception modules; the groups are
    # separated by max-pool layers below.
    first_group = (
        (64, 96, 128, 16, 32, 32),
        (128, 128, 192, 32, 96, 64),
    )
    second_group = (
        (192, 96, 208, 16, 48, 64),
        (160, 112, 224, 24, 64, 64),
        (128, 128, 256, 24, 64, 64),
        (112, 144, 288, 32, 64, 64),
        (256, 160, 320, 32, 128, 128),
    )
    third_group = (
        (256, 160, 320, 32, 128, 128),
        (384, 192, 384, 48, 128, 128),
    )

    cnn.conv(64, 7, 7, 2, 2)
    cnn.mpool(3, 3, 2, 2, mode='SAME')
    cnn.conv(64, 1, 1)
    cnn.conv(192, 3, 3)
    cnn.mpool(3, 3, 2, 2, mode='SAME')
    for filters in first_group:
      add_inception(*filters)
    cnn.mpool(3, 3, 2, 2, mode='SAME')
    for filters in second_group:
      add_inception(*filters)
    cnn.mpool(3, 3, 2, 2, mode='SAME')
    for filters in third_group:
      add_inception(*filters)
    cnn.apool(7, 7, 1, 1, mode='VALID')
    cnn.reshape([-1, 1024])
Loading

0 comments on commit f3dcb5d

Please sign in to comment.