
Commit 80ccd1a
fix: Handle NoneType runningmean and runningvar in instance_norm and subsequently batch_norm for the torch frontend and the corresponding functional and backend APIs
hmahmood24 committed Aug 29, 2024
1 parent 0af909e commit 80ccd1a
Showing 6 changed files with 90 additions and 55 deletions.
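Context for the change: `torch.nn.functional.batch_norm` and `torch.nn.functional.instance_norm` accept `None` for `running_mean` and `running_var`, but the code paths below previously called `.detach()`, `.clone()`, `.ndim`, or arithmetic on those arguments unconditionally, so such calls crashed with `AttributeError: 'NoneType' object has no attribute ...`. A minimal sketch of the guard pattern applied throughout this commit (`make_running_copy` is an illustrative helper, not part of the codebase):

```python
from typing import Optional

import torch

def make_running_copy(mean: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Pre-fix pattern: `mean.detach().clone()` raises AttributeError when
    # mean is None. Post-fix pattern: propagate None untouched.
    return mean.detach().clone() if mean is not None else mean

print(make_running_copy(None))            # None
print(make_running_copy(torch.zeros(3)))  # tensor([0., 0., 0.])
```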
8 changes: 4 additions & 4 deletions ivy/data_classes/array/experimental/norms.py
@@ -75,8 +75,8 @@ def l2_normalize(

def batch_norm(
self: Union[ivy.NativeArray, ivy.Array],
-        mean: Union[ivy.NativeArray, ivy.Array],
-        variance: Union[ivy.NativeArray, ivy.Array],
+        mean: Optional[Union[ivy.NativeArray, ivy.Array]],
+        variance: Optional[Union[ivy.NativeArray, ivy.Array]],
/,
*,
offset: Optional[Union[ivy.NativeArray, ivy.Array]] = None,
@@ -145,8 +145,8 @@ def batch_norm(

def instance_norm(
self: Union[ivy.NativeArray, ivy.Array],
-        mean: Union[ivy.NativeArray, ivy.Array],
-        variance: Union[ivy.NativeArray, ivy.Array],
+        mean: Optional[Union[ivy.NativeArray, ivy.Array]],
+        variance: Optional[Union[ivy.NativeArray, ivy.Array]],
/,
*,
offset: Optional[Union[ivy.NativeArray, ivy.Array]] = None,
16 changes: 8 additions & 8 deletions ivy/data_classes/container/experimental/norms.py
@@ -246,8 +246,8 @@ def l2_normalize(
@staticmethod
def static_batch_norm(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
-        mean: Union[ivy.NativeArray, ivy.Array, ivy.Container],
-        variance: Union[ivy.NativeArray, ivy.Array, ivy.Container],
+        mean: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]],
+        variance: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]],
/,
*,
offset: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]] = None,
@@ -342,8 +342,8 @@ def static_batch_norm(

def batch_norm(
self: Union[ivy.Array, ivy.NativeArray, ivy.Container],
-        mean: Union[ivy.NativeArray, ivy.Array, ivy.Container],
-        variance: Union[ivy.NativeArray, ivy.Array, ivy.Container],
+        mean: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]],
+        variance: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]],
/,
*,
offset: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]] = None,
@@ -438,8 +438,8 @@ def batch_norm(
@staticmethod
def static_instance_norm(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
-        mean: Union[ivy.NativeArray, ivy.Array, ivy.Container],
-        variance: Union[ivy.NativeArray, ivy.Array, ivy.Container],
+        mean: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]],
+        variance: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]],
/,
*,
offset: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]] = None,
@@ -532,8 +532,8 @@ def static_instance_norm(

def instance_norm(
self: Union[ivy.Array, ivy.NativeArray, ivy.Container],
-        mean: Union[ivy.NativeArray, ivy.Array, ivy.Container],
-        variance: Union[ivy.NativeArray, ivy.Array, ivy.Container],
+        mean: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]],
+        variance: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]],
/,
*,
offset: Optional[Union[ivy.NativeArray, ivy.Array, ivy.Container]] = None,
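With the annotations relaxed in the array and container APIs above, the stateless `batch_norm` can be called without running statistics whenever they are derived from the batch itself. A hedged usage sketch (assuming a NumPy backend is available; the returned running stats are simply `None` when none were supplied):

```python
import ivy

ivy.set_backend("numpy")

x = ivy.random_normal(shape=(4, 8, 8, 3))  # N, *S, C layout

# In training mode the batch statistics are computed from x, so the
# running stats may be omitted entirely.
xn, rm, rv = ivy.batch_norm(x, None, None, training=True)
print(rm, rv)  # None None
```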
39 changes: 28 additions & 11 deletions ivy/functional/backends/tensorflow/experimental/norms.py
@@ -69,8 +69,8 @@ def local_response_norm(
@with_unsupported_dtypes({"2.15.0 and below": ("float16", "bfloat16")}, backend_version)
def batch_norm(
x: Union[tf.Tensor, tf.Variable],
-    mean: Union[tf.Tensor, tf.Variable],
-    variance: Union[tf.Tensor, tf.Variable],
+    mean: Optional[Union[tf.Tensor, tf.Variable]],
+    variance: Optional[Union[tf.Tensor, tf.Variable]],
/,
*,
scale: Optional[Union[tf.Tensor, tf.Variable]] = None,
@@ -103,9 +103,15 @@ def batch_norm(
dims = (0, *range(1, xdims - 1))
mean = tf.math.reduce_mean(x, axis=dims)
variance = tf.math.reduce_variance(x, axis=dims)
-        runningmean = (1 - momentum) * runningmean + momentum * mean
-        runningvariance = (1 - momentum) * runningvariance + momentum * variance * n / (
-            n - 1
+        runningmean = (
+            ((1 - momentum) * runningmean + momentum * mean)
+            if runningmean is not None
+            else runningmean
+        )
+        runningvariance = (
+            (1 - momentum) * runningvariance + momentum * variance * n / (n - 1)
+            if runningvariance is not None
+            else runningvariance
         )

inv = 1.0 / tf.math.sqrt(variance + eps)
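For reference, the update now being skipped for `None` stats is the usual exponential moving average, with the biased batch variance rescaled by `n / (n - 1)` so the running variance tracks an unbiased estimate. A standalone sketch of the same arithmetic (the helper name is illustrative, not part of the codebase):

```python
import numpy as np

def ema_update(running_mean, running_var, batch_mean, batch_var, n, momentum=0.1):
    # n is the number of elements reduced per channel; n / (n - 1) converts
    # the biased batch variance into an unbiased running estimate.
    new_mean = (1 - momentum) * running_mean + momentum * batch_mean
    new_var = (1 - momentum) * running_var + momentum * batch_var * n / (n - 1)
    return new_mean, new_var
```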
@@ -126,8 +132,8 @@

def instance_norm(
x: Union[tf.Tensor, tf.Variable],
-    mean: Union[tf.Tensor, tf.Variable],
-    variance: Union[tf.Tensor, tf.Variable],
+    mean: Optional[Union[tf.Tensor, tf.Variable]] = None,
+    variance: Optional[Union[tf.Tensor, tf.Variable]] = None,
/,
*,
scale: Optional[Union[tf.Tensor, tf.Variable]] = None,
@@ -161,8 +167,8 @@ def instance_norm(
C = x.shape[-1]
S = x.shape[0:-2]
x = tf.reshape(x, (1, *S, N * C))
-    mean = tf.tile(mean, [N])
-    variance = tf.tile(variance, [N])
+    mean = tf.tile(mean, [N]) if mean is not None else mean
+    variance = tf.tile(variance, [N]) if variance is not None else variance
if scale is not None:
scale = tf.tile(scale, [N])
if offset is not None:
@@ -187,10 +193,21 @@
xnormalized, perm=(xdims - 2, *range(0, xdims - 2), xdims - 1)
)

+    runningmean = (
+        tf.reduce_mean(tf.reshape(runningmean, (N, C)), axis=0)
+        if runningmean is not None
+        else runningmean
+    )
+    runningvariance = (
+        tf.reduce_mean(tf.reshape(runningvariance, (N, C)), axis=0)
+        if runningvariance is not None
+        else runningvariance
+    )

return (
xnormalized,
-        tf.reduce_mean(tf.reshape(runningmean, (N, C)), axis=0),
-        tf.reduce_mean(tf.reshape(runningvariance, (N, C)), axis=0),
+        runningmean,
+        runningvariance,
)
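The instance-norm path above folds the batch axis into the channel axis so one batch-norm pass normalizes every (sample, channel) slice independently, then averages the per-sample running stats back to per-channel shape; with this commit that averaging is skipped when the stats are `None`. A small NumPy sketch of the folding idea (shapes illustrative, not the backend's exact code path):

```python
import numpy as np

N, H, W, C = 2, 4, 4, 3
x = np.random.randn(N, H, W, C)

# Move the batch axis next to the channels, then fuse them: instance norm
# over (N, C) pairs becomes batch norm with batch size 1 and N*C channels.
folded = x.transpose(1, 2, 0, 3).reshape(1, H, W, N * C)

per_instance = x.mean(axis=(1, 2))                   # (N, C)
per_folded = folded.mean(axis=(1, 2)).reshape(N, C)  # (N, C)
assert np.allclose(per_instance, per_folded)
```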


26 changes: 13 additions & 13 deletions ivy/functional/backends/torch/experimental/norms.py
@@ -59,8 +59,8 @@ def local_response_norm(
@with_unsupported_dtypes({"2.2 and below": ("bfloat16", "float16")}, backend_version)
def batch_norm(
x: torch.Tensor,
-    mean: torch.Tensor,
-    variance: torch.Tensor,
+    mean: Optional[torch.Tensor],
+    variance: Optional[torch.Tensor],
/,
*,
scale: Optional[torch.Tensor] = None,
@@ -74,8 +74,8 @@ def batch_norm(
xdims = x.ndim
if data_format == "NSC":
x = torch.permute(x, dims=(0, xdims - 1, *range(1, xdims - 1)))
-    runningmean = mean.detach().clone()
-    runningvariance = variance.detach().clone()
+    runningmean = mean.detach().clone() if mean is not None else mean
+    runningvariance = variance.detach().clone() if variance is not None else variance
xnormalized = torch.nn.functional.batch_norm(
x,
runningmean,
@@ -94,8 +94,8 @@ def batch_norm(
batch_norm.partial_mixed_handler = (
lambda x, mean, variance, scale=None, offset=None, **kwargs: (
x.ndim > 1
-        and mean.ndim == 1
-        and variance.ndim == 1
+        and (mean is None or mean.ndim == 1)
+        and (variance is None or variance.ndim == 1)
and (scale is None or scale.ndim == 1)
and (offset is None or offset.ndim == 1)
)
@@ -105,8 +105,8 @@
@with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, backend_version)
def instance_norm(
x: torch.Tensor,
-    mean: torch.Tensor,
-    variance: torch.Tensor,
+    mean: Optional[torch.Tensor] = None,
+    variance: Optional[torch.Tensor] = None,
/,
*,
scale: Optional[torch.Tensor] = None,
@@ -117,8 +117,8 @@ def instance_norm(
data_format: Optional[str] = "NSC",
out: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    runningmean = mean.clone()
-    runningvariance = variance.clone()
+    runningmean = mean.clone() if mean is not None else mean
+    runningvariance = variance.clone() if variance is not None else variance
# reshape from N, *S, C to N, C, *S
xdims = x.ndim
if data_format == "NSC":
@@ -140,10 +140,10 @@


instance_norm.partial_mixed_handler = (
-    lambda x, mean, variance, scale=None, offset=None, **kwargs: (
+    lambda x, mean=None, variance=None, scale=None, offset=None, **kwargs: (
         x.ndim > 1
-        and mean.ndim == 1
-        and variance.ndim == 1
+        and (mean is None or mean.ndim == 1)
+        and (variance is None or variance.ndim == 1)
and (scale is None or scale.ndim == 1)
and (offset is None or offset.ndim == 1)
)
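`partial_mixed_handler` is the predicate ivy evaluates to decide whether a call may be dispatched to the backend-native kernel or must fall back to the compositional implementation, so it has to tolerate `None` before touching `.ndim`. A trivial sketch of the short-circuit (values illustrative):

```python
import torch

x = torch.randn(4, 3, 8, 8)
mean = None

# `mean is None or mean.ndim == 1` never evaluates `.ndim` on None, so the
# dispatch predicate itself no longer crashes for optional running stats.
print(x.ndim > 1 and (mean is None or mean.ndim == 1))  # True
```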
20 changes: 12 additions & 8 deletions ivy/functional/frontends/torch/nn/functional/norms.py
@@ -15,8 +15,8 @@
@to_ivy_arrays_and_back
def batch_norm(
input,
-    running_mean,
-    running_var,
+    running_mean=None,
+    running_var=None,
weight=None,
bias=None,
training=False,
@@ -35,8 +35,10 @@ def batch_norm(
data_format="NCS",
)
if training:
-        ivy.inplace_update(running_mean, mean)
-        ivy.inplace_update(running_var, var)
+        if running_mean is not None:
+            ivy.inplace_update(running_mean, mean)
+        if running_var is not None:
+            ivy.inplace_update(running_var, var)
return normalized
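With the guards in place the frontend mirrors PyTorch's contract: running statistics are updated in place only when they are actually provided. A reference sketch against stock PyTorch, whose behavior the frontend is expected to match:

```python
import torch
import torch.nn.functional as F

x = torch.randn(16, 3, 8, 8)
running_mean = torch.zeros(3)
running_var = torch.ones(3)

# Stats provided: updated in place during training.
F.batch_norm(x, running_mean, running_var, training=True, momentum=0.1)
assert not torch.equal(running_mean, torch.zeros(3))

# Stats omitted: batch statistics are used and there is nothing to update.
F.batch_norm(x, None, None, training=True)
```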


@@ -68,8 +70,8 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
@to_ivy_arrays_and_back
def instance_norm(
input,
-    running_mean,
-    running_var,
+    running_mean=None,
+    running_var=None,
weight=None,
bias=None,
use_input_stats=False,
@@ -87,8 +89,10 @@
momentum=momentum,
data_format="NCS",
)
-    ivy.inplace_update(running_mean, mean)
-    ivy.inplace_update(running_var, var)
+    if running_mean is not None:
+        ivy.inplace_update(running_mean, mean)
+    if running_var is not None:
+        ivy.inplace_update(running_var, var)
return normalized
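The same applies to `instance_norm`, where stock PyTorch likewise treats the running statistics as optional and, by default, tracks none at all:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 8, 8)

# No running stats: per-instance statistics are computed from the input.
out = F.instance_norm(x)
print(out.shape)  # torch.Size([4, 3, 8, 8])
```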


36 changes: 25 additions & 11 deletions ivy/functional/ivy/experimental/norms.py
@@ -193,8 +193,8 @@ def local_response_norm(
@handle_array_function
def batch_norm(
x: Union[ivy.NativeArray, ivy.Array],
-    mean: Union[ivy.NativeArray, ivy.Array],
-    variance: Union[ivy.NativeArray, ivy.Array],
+    mean: Optional[Union[ivy.NativeArray, ivy.Array]],
+    variance: Optional[Union[ivy.NativeArray, ivy.Array]],
/,
*,
offset: Optional[Union[ivy.NativeArray, ivy.Array]] = None,
@@ -270,9 +270,15 @@ def batch_norm(
dims = (0, *range(1, xdims - 1))
mean = ivy.mean(x, axis=dims)
variance = ivy.var(x, axis=dims)
-        runningmean = (1 - momentum) * runningmean + momentum * mean
-        runningvariance = (1 - momentum) * runningvariance + momentum * variance * n / (
-            n - 1
+        runningmean = (
+            (1 - momentum) * runningmean + momentum * mean
+            if runningmean is not None
+            else runningmean
+        )
+        runningvariance = (
+            (1 - momentum) * runningvariance + momentum * variance * n / (n - 1)
+            if runningvariance is not None
+            else runningvariance
         )
inv = 1.0 / ivy.sqrt(variance + eps)
offset = 0 if offset is None else offset
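For reference, the transform built from `inv` above is the standard affine normalization, with `scale` and `offset` defaulting to the identity when `None`. A minimal NumPy rendering of the same elementwise formula (the helper name is illustrative):

```python
import numpy as np

def affine_norm(x, mean, variance, scale=1.0, offset=0.0, eps=1e-5):
    # inv = 1/sqrt(variance + eps), then
    # x_normalized = (x - mean) * inv * scale + offset.
    inv = 1.0 / np.sqrt(variance + eps)
    return (x - mean) * inv * scale + offset
```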
@@ -313,8 +319,8 @@ def batch_norm(
@handle_array_function
def instance_norm(
x: Union[ivy.NativeArray, ivy.Array],
-    mean: Union[ivy.NativeArray, ivy.Array],
-    variance: Union[ivy.NativeArray, ivy.Array],
+    mean: Optional[Union[ivy.NativeArray, ivy.Array]],
+    variance: Optional[Union[ivy.NativeArray, ivy.Array]],
/,
*,
offset: Optional[Union[ivy.NativeArray, ivy.Array]] = None,
@@ -387,8 +393,8 @@ def instance_norm(
C = x.shape[-1]
S = x.shape[0:-2]
x = x.reshape((1, *S, N * C))
-    mean = ivy.tile(mean, N)
-    variance = ivy.tile(variance, N)
+    mean = ivy.tile(mean, N) if mean is not None else mean
+    variance = ivy.tile(variance, N) if variance is not None else variance
if scale is not None:
scale = ivy.tile(scale, N)
if offset is not None:
@@ -414,8 +420,16 @@
xnormalized, axes=(xdims - 2, *range(0, xdims - 2), xdims - 1)
)

-    runningmean = runningmean.reshape((N, C)).mean(axis=0)
-    runningvariance = runningvariance.reshape((N, C)).mean(axis=0)
+    runningmean = (
+        runningmean.reshape((N, C)).mean(axis=0)
+        if runningmean is not None
+        else runningmean
+    )
+    runningvariance = (
+        runningvariance.reshape((N, C)).mean(axis=0)
+        if runningvariance is not None
+        else runningvariance
+    )

if ivy.exists(out):
xnormalized = ivy.inplace_update(out[0], xnormalized)
