import tensorflow as tf

-_DEFAULT_CUDA_VERISON = "10.1"
-_DEFAULT_CUDNN_VERSION = "7"
_TFA_BAZELRC = ".bazelrc"


# Writes variables to bazelrc file
-def write_to_bazelrc(line):
+def write(line):
    with open(_TFA_BAZELRC, "a") as f:
        f.write(line + "\n")


-def write_action_env_to_bazelrc(var_name, var):
-    write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
+def write_action_env(var_name, var):
+    write('build --action_env {}="{}"'.format(var_name, var))


def is_macos():
@@ -46,13 +44,6 @@ def is_windows():
    return platform.system() == "Windows"


-def get_input(question):
-    try:
-        return input(question)
-    except EOFError:
-        return ""
-
-
def get_tf_header_dir():
    import tensorflow as tf
@@ -98,107 +89,42 @@ def create_build_configuration():
    logging.disable(logging.WARNING)

-    write_action_env_to_bazelrc("TF_HEADER_DIR", get_tf_header_dir())
-    write_action_env_to_bazelrc("TF_SHARED_LIBRARY_DIR", get_tf_shared_lib_dir())
-    write_action_env_to_bazelrc("TF_SHARED_LIBRARY_NAME", get_shared_lib_name())
-    write_action_env_to_bazelrc("TF_CXX11_ABI_FLAG", tf.sysconfig.CXX11_ABI_FLAG)
-
-    write_to_bazelrc("build --spawn_strategy=standalone")
-    write_to_bazelrc("build --strategy=Genrule=standalone")
-    write_to_bazelrc("build -c opt")
-
-    _TF_NEED_CUDA = os.getenv("TF_NEED_CUDA")
-
-    while _TF_NEED_CUDA is None:
-        print()
-        answer = get_input("Do you want to build GPU ops? [y/N] ")
-        if answer in ("Y", "y"):
-            print("> Building GPU & CPU ops")
-            _TF_NEED_CUDA = "1"
-        elif answer in ("N", "n", ""):
-            print("> Building only CPU ops")
-            _TF_NEED_CUDA = "0"
-        else:
-            print("Invalid selection:", answer)
-
-    if _TF_NEED_CUDA == "1":
-        configure_cuda()
+    write_action_env("TF_HEADER_DIR", get_tf_header_dir())
+    write_action_env("TF_SHARED_LIBRARY_DIR", get_tf_shared_lib_dir())
+    write_action_env("TF_SHARED_LIBRARY_NAME", get_shared_lib_name())
+    write_action_env("TF_CXX11_ABI_FLAG", tf.sysconfig.CXX11_ABI_FLAG)

-    print()
-    print("Build configurations successfully written to", _TFA_BAZELRC)
-    print(pathlib.Path(_TFA_BAZELRC).read_text())
-    print()
+    write("build --spawn_strategy=standalone")
+    write("build --strategy=Genrule=standalone")
+    write("build -c opt")

+    if os.getenv("TF_NEED_CUDA", "0") == "1":
+        print("> Building GPU & CPU ops")
+        configure_cuda()
+    else:
+        print("> Building only CPU ops")

-def get_cuda_toolkit_path():
-    default = "/usr/local/cuda"
-    cuda_toolkit_path = os.getenv("CUDA_TOOLKIT_PATH")
-    if cuda_toolkit_path is None:
-        answer = get_input(
-            "Please specify the location of CUDA. [Default is {}]: ".format(default)
-        )
-        cuda_toolkit_path = answer or default
-    print("> CUDA installation path:", cuda_toolkit_path)
-    print()
-    return cuda_toolkit_path
-
-
-def get_cudnn_install_path():
-    default = "/usr/lib/x86_64-linux-gnu"
-    cudnn_install_path = os.getenv("CUDNN_INSTALL_PATH")
-    if cudnn_install_path is None:
-        answer = get_input(
-            "Please specify the location of cuDNN installation. [Default is {}]: ".format(
-                default
-            )
-        )
-        cudnn_install_path = answer or default
-    print("> cuDNN installation path:", cudnn_install_path)
    print()
-    return cudnn_install_path
+    print("Build configurations successfully written to", _TFA_BAZELRC, ":\n")
+    print(pathlib.Path(_TFA_BAZELRC).read_text())


def configure_cuda():
-    _TF_CUDA_VERSION = os.getenv("TF_CUDA_VERSION")
-    _TF_CUDNN_VERSION = os.getenv("TF_CUDNN_VERSION")
-
-    print()
-    print("Configuring GPU setup...")
-
-    if _TF_CUDA_VERSION is None:
-        answer = get_input(
-            "Please specify the CUDA version [Default is {}]: ".format(
-                _DEFAULT_CUDA_VERISON
-            )
-        )
-        _TF_CUDA_VERSION = answer or _DEFAULT_CUDA_VERISON
-        print("> Using CUDA version:", _TF_CUDA_VERSION)
-        print()
-
-    if _TF_CUDNN_VERSION is None:
-        answer = get_input(
-            "Please specify the cuDNN major version [Default is {}]: ".format(
-                _DEFAULT_CUDNN_VERSION
-            )
-        )
-        _TF_CUDNN_VERSION = answer or _DEFAULT_CUDNN_VERSION
-        print("> Using cuDNN version:", _TF_CUDNN_VERSION)
-        print()
-
-    write_action_env_to_bazelrc("TF_NEED_CUDA", "1")
-    write_action_env_to_bazelrc("CUDA_TOOLKIT_PATH", get_cuda_toolkit_path())
-    write_action_env_to_bazelrc("CUDNN_INSTALL_PATH", get_cudnn_install_path())
-    write_action_env_to_bazelrc("TF_CUDA_VERSION", _TF_CUDA_VERSION)
-    write_action_env_to_bazelrc("TF_CUDNN_VERSION", _TF_CUDNN_VERSION)
-
-    write_to_bazelrc("test --config=cuda")
-    write_to_bazelrc("build --config=cuda")
-    write_to_bazelrc(
-        "build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true"
+    write_action_env("TF_NEED_CUDA", "1")
+    write_action_env(
+        "CUDA_TOOLKIT_PATH", os.getenv("CUDA_TOOLKIT_PATH", "/usr/local/cuda")
    )
-    write_to_bazelrc(
-        "build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain"
+    write_action_env(
+        "CUDNN_INSTALL_PATH",
+        os.getenv("CUDNN_INSTALL_PATH", "/usr/lib/x86_64-linux-gnu"),
    )
+    write_action_env("TF_CUDA_VERSION", os.getenv("TF_CUDA_VERSION", "10.1"))
+    write_action_env("TF_CUDNN_VERSION", os.getenv("TF_CUDNN_VERSION", "7"))
+
+    write("test --config=cuda")
+    write("build --config=cuda")
+    write("build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true")
+    write("build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain")


if __name__ == "__main__":
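With the interactive prompts removed, the GPU build is configured entirely through the environment variables that configure_cuda() now reads with os.getenv(). A minimal sketch of a non-interactive GPU run under that assumption; the import path ("configure") is a guess since the diff does not show the file name, and the values are simply the defaults inlined above:

```python
import os

# Assumed module name; the diff only shows the functions, not the file.
from configure import create_build_configuration

# Same fallback values that configure_cuda() uses when these are unset.
os.environ["TF_NEED_CUDA"] = "1"  # any other value keeps the CPU-only path
os.environ["CUDA_TOOLKIT_PATH"] = "/usr/local/cuda"
os.environ["CUDNN_INSTALL_PATH"] = "/usr/lib/x86_64-linux-gnu"
os.environ["TF_CUDA_VERSION"] = "10.1"
os.environ["TF_CUDNN_VERSION"] = "7"

# Appends lines such as
#   build --action_env TF_CUDA_VERSION="10.1"
#   build --config=cuda
# to .bazelrc via write()/write_action_env(), then prints the file contents.
create_build_configuration()
```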