Skip to content

Commit

Permalink
fix: Fixed the API for torch.std and torch.std_mean
Browse files Browse the repository at this point in the history
  • Loading branch information
hmahmood24 committed Aug 29, 2024
1 parent c528a41 commit 0af909e
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions ivy/functional/frontends/torch/reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,16 +335,18 @@ def quantile(input, q, dim=None, keepdim=False, *, interpolation="linear", out=N
@numpy_to_torch_style_args
@to_ivy_arrays_and_back
@with_unsupported_dtypes({"2.2 and below": ("float16", "bool", "integer")}, "torch")
def std(input, dim=None, unbiased=True, keepdim=False, *, out=None):
return ivy.std(input, axis=dim, correction=int(unbiased), keepdims=keepdim, out=out)
def std(input, dim=None, *, correction=1, keepdim=False, out=None):
return ivy.std(
input, axis=dim, correction=int(correction), keepdims=keepdim, out=out
)


@numpy_to_torch_style_args
@to_ivy_arrays_and_back
@with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, "torch")
def std_mean(input, dim, unbiased, keepdim=False, *, out=None):
def std_mean(input, dim, *, correction=1, keepdim=False, out=None):
temp_std = ivy.std(
input, axis=dim, correction=int(unbiased), keepdims=keepdim, out=out
input, axis=dim, correction=int(correction), keepdims=keepdim, out=out
)
temp_mean = ivy.mean(input, axis=dim, keepdims=keepdim, out=out)
return temp_std, temp_mean
Expand Down

0 comments on commit 0af909e

Please sign in to comment.