@@ -70,6 +70,34 @@ def get_bdim_choices(num_tensors):
     assert choices[-1] == (None,) * num_tensors
     return tuple(choices[:-1])

+# NB: This is O(2 ** num_tensors).
+# num_tensors ranges from 1 to 10, with 2-4 being most common.
+# Try not to make it any more expensive if you're modifying it.
+def get_bdim_choices_batch_norm(num_tensors, _, running_mean=None, running_var=None, *args):
+    choices = []
+    options = (-1, None)
+
+    # Instance norm turns these into unbatched 0 tensors, so we cannot batch the input if either is not specified.
+    if running_mean is None or running_var is None:
+        choices.append((None,) + (0,) * (num_tensors - 1))
+        for choice in itertools.product(options, repeat=num_tensors - 1):
+            choices.append((None,) + choice)
+
+    else:
+        # running_mean and running_var are specified as tensors. Batch norm doesn't work if the input is batched but
+        # running_mean/var are unbatched, so this tests all other cases.
+        choices.append((0,) * num_tensors)
+        for choice in itertools.product(options, repeat=num_tensors):
+            input_bdim = choice[0]
+            running_mean_bdim = choice[1]
+            running_var_bdim = choice[2]
+            if input_bdim and (not running_mean_bdim or not running_var_bdim):
+                continue
+            choices.append(choice)
+
+    assert choices[-1] == (None,) * num_tensors
+    return tuple(choices[:-1])
+

 def add_batch_dim(arg, bdim, batch_size=3):
     assert bdim == 0 or bdim == -1
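
As a sanity check on the enumeration above, here is a standalone sketch (not part of the patch; the function name is illustrative) of what get_bdim_choices_batch_norm produces when running stats are provided. Note that a bdim of -1 is truthy while None is falsy, which is what the `if input_bdim and ...` filter relies on:

import itertools

def sketch_batch_norm_choices(num_tensors=3):
    # Mirrors the else-branch above: start with the all-batched choice...
    choices = [(0,) * num_tensors]
    for choice in itertools.product((-1, None), repeat=num_tensors):
        input_bdim, mean_bdim, var_bdim = choice[0], choice[1], choice[2]
        # -1 is truthy, None is falsy: drop choices that batch the input
        # while leaving running_mean or running_var unbatched.
        if input_bdim and (not mean_bdim or not var_bdim):
            continue
        choices.append(choice)
    # ...and drop the trailing all-None choice, as the final assert guarantees.
    return tuple(choices[:-1])

assert sketch_batch_norm_choices() == (
    (0, 0, 0), (-1, -1, -1), (None, -1, -1), (None, -1, None), (None, None, -1))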
@@ -93,12 +121,7 @@ def construct_in_dims(bdim_choice_for_tensors, is_tensors):
         result.append(next(bdim))
     return tuple(result)

-
-def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=2, *, for_batch_norm=False):
-    if for_batch_norm:
-        # TODO: delete this path
-        return get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size)
-
+def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=2):
     flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
     is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
     bdim_choices = get_bdim_choices(sum(is_tensors))
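
construct_in_dims is only partially visible in this hunk. The contract the new code relies on appears to be: spread a per-tensor bdim choice across all flattened args, pairing None with every non-tensor argument. A minimal sketch under that assumption (only the last two lines below actually appear in the hunk):

def construct_in_dims_sketch(bdim_choice_for_tensors, is_tensors):
    # Hypothetical reconstruction of the elided body.
    bdim = iter(bdim_choice_for_tensors)
    result = []
    for is_tensor in is_tensors:
        if not is_tensor:
            result.append(None)  # non-tensor args never get a batch dim
            continue
        result.append(next(bdim))
    return tuple(result)

# e.g. construct_in_dims_sketch((0, None), [True, False, True]) == (0, None, None)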
@@ -120,87 +143,41 @@ def get_batched_arg(arg, bdim):
         yield batched_args, in_dims, kwarg_values


-def get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size=3, bdims=(0, -1)):
-    for_batch_norm = True
-    assert bdims == (0,) or bdims == (0, -1)
-
-    def add_batch_dim(arg, bdim, batch_size=3):
-        assert bdim == 0 or bdim == -1
-        if isinstance(arg, torch.Tensor):
-            if bdim == 0:
-                shape = [1] * len(arg.shape)
-                shape.insert(bdim, batch_size)
-                return (arg.repeat(shape), bdim)
-            if bdim == -1:
-                arg = arg.unsqueeze(-1).expand(*arg.shape, batch_size).contiguous()
-                return (arg, bdim)
-            assert False
-        else:
-            return (arg, None)
-    for bdim in bdims:
-        batch_choices = []
-
-        def add_batch_choices(a):
-            if isinstance(a, torch.Tensor):
-                batched_val = add_batch_dim(a, bdim, batch_size)
-                batch_choices.append((batched_val, (a, None)))
-            else:
-                batch_choices.append(((a, None),))
-
-        flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
-        if for_batch_norm:
-            # Batch norm is unique because the running_mean and running_var are updated in place.
-            # Therefore, they cannot be unbatched if the input is batched. The case where both are
-            # unbatched is added at the end
-            if len(flat_args) >= 3:
-                add_batch_choices(flat_args[0])  # input can be batched or unbatched
-                batch_choices.append((add_batch_dim(flat_args[1], bdim, batch_size),))  # running_mean must be batched
-                batch_choices.append((add_batch_dim(flat_args[2], bdim, batch_size),))  # running_var must be batched
-                orig_flat_args = flat_args
-                flat_args = orig_flat_args[3:]
-            else:
-                # TODO: None defaults in instance norm create empty tensors that are written to and mean that we must
-                # have unbatched inputs. None in the running mean/running var shouldn't make a tensor
-                batch_choices.append(((flat_args[0], None),))  # input must be unbatched
-                if len(flat_args) == 2:
-                    batch_choices.append((add_batch_dim(flat_args[1], bdim, batch_size),))
-                orig_flat_args = flat_args
-                flat_args = []
-
-        for arg in flat_args:
-            add_batch_choices(arg)
-
-        for batched_values in itertools.product(*batch_choices):
-            batched_args, in_dims = zip(*batched_values)
-
-            if all([i is None for i in in_dims]):
-                continue
-
-            yield pytree.tree_unflatten(batched_args, arg_spec), pytree.tree_unflatten(in_dims, arg_spec), kwarg_values
+def get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size=2):
+    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
+    is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
+    num_tensors = sum(is_tensors)
+    if num_tensors == 1:  # with only an input, we can't batch it, since running_mean/var would be seen as unbatched tensors
+        return
+    bdim_choices = get_bdim_choices_batch_norm(num_tensors, *arg_values)

-        if for_batch_norm and len(orig_flat_args) >= 2:
-            # Adds the case where input, running_mean, and running_var are all unbatched
-            batch_choices[0] = ((orig_flat_args[0], None),)
-            batch_choices[1] = ((orig_flat_args[1], None),)
-            if len(orig_flat_args) >= 3:
-                batch_choices[2] = ((orig_flat_args[2], None),)
-            for batched_values in itertools.product(*batch_choices):
-                batched_args, in_dims = zip(*batched_values)
+    @memoize
+    def get_batched_arg(arg, bdim):
+        assert isinstance(arg, torch.Tensor)
+        assert bdim is not None
+        result, _ = add_batch_dim(arg, bdim, batch_size)
+        return result

-                if all([i is None for i in in_dims]):
-                    continue
+    for bdim_choice in bdim_choices:
+        flat_in_dims = construct_in_dims(bdim_choice, is_tensors)

-                batched_args_tuple = pytree.tree_unflatten(batched_args, arg_spec)
-                in_dims_tuple = pytree.tree_unflatten(in_dims, arg_spec)
-                yield batched_args_tuple, in_dims_tuple, kwarg_values
+        flat_batched_args = tuple(arg if in_dim is None else get_batched_arg(arg, in_dim)
+                                  for arg, in_dim in zip(flat_args, flat_in_dims))
+        batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec)
+        in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec)
+        yield batched_args, in_dims, kwarg_values
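
memoize is referenced here but defined elsewhere in the file. A minimal sketch of the caching behavior the generator appears to rely on, keyed by tensor identity, under the assumption it is a plain argument-keyed cache:

import functools

def memoize_sketch(fn):
    # Hypothetical stand-in for the file's memoize helper (not shown in this diff).
    memo = {}

    @functools.wraps(fn)
    def wrapper(*args):
        # torch.Tensor hashes by object identity, so the args tuple is a usable
        # dict key: the same flat_args tensors recur across bdim choices, so
        # each tensor is expanded at most once per batch dim.
        if args not in memo:
            memo[args] = fn(*args)
        return memo[args]
    return wrapper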


 def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, opinfo=None, compute_loop_out=True):
     out_dim = 0
     batch_size = 2
     batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")  # instance norm calls batch norm
-    for_batch_norm = opinfo is not None and opinfo.name in batch_norm_fns
-    generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, for_batch_norm=for_batch_norm)
+
+    if opinfo is not None and opinfo.name in batch_norm_fns:
+        generator = get_exhaustive_batched_inputs_batch_norm(arg_values, kwarg_values, batch_size)
+    else:
+        generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size)
+
     for batched_args, in_dims, kwarg_values in generator:
         if compute_loop_out:
             loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values)
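
For context, these generators are consumed roughly as follows; the vmap call below is a hedged sketch of the pattern (the actual comparison against loop_out sits outside this hunk, and the function name is illustrative):

from functorch import vmap  # torch.func.vmap in newer PyTorch releases

def check_op_exhaustively(op, arg_values, kwarg_values, batch_size=2):
    # Sketch only: run the op under vmap for every batching configuration yielded.
    for batched_args, in_dims, kwargs in get_exhaustive_batched_inputs(
            arg_values, kwarg_values, batch_size):
        batched_out = vmap(op, in_dims=in_dims)(*batched_args, **kwargs)
        # ...compare batched_out against a per-example loop reference here.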