Skip to content

Commit

Permalink
Minor cleanup for PR #260 (#537)
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 authored Jun 5, 2024
1 parent fbaa3a8 commit 0c6e38c
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 54 deletions.
3 changes: 0 additions & 3 deletions thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,9 +1096,6 @@ def _div_prim_grad(a: Number | TensorProxy, b: Number | TensorProxy, /) -> Numbe
register_grad(pids.GT, prims.gt)
register_grad(pids.LE, prims.le)
register_grad(pids.LT, prims.lt)
register_grad(pids.NE, prims.ne)
register_grad(pids.GT, prims.gt)
register_grad(pids.LE, prims.le)


@torchctx
Expand Down
8 changes: 4 additions & 4 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -7858,12 +7858,12 @@ def cross_entropy_error_generator(op, device, dtype=torch.float32, **kwargs):
"Expected the input tensor to have (.*?) dimensions, but it has (.*?) dimensions.",
)

# target shape is input shape except channels dimension
# target shape is input shape except class dimension
incorrect_batch_target = make((10,), low=0, high=C, dtype=torch.long, requires_grad=False)
yield (
SampleInput(valid_input, incorrect_batch_target),
RuntimeError,
"Expected the target tensor to have the same shape as the input tensor except for the channels dimension \
"Expected the target tensor to have the same shape as the input tensor except for the class dimension \
(.*?), but it has shape (.*?).",
)

Expand Down Expand Up @@ -8013,12 +8013,12 @@ def nll_loss_error_generator(op, device, dtype=torch.float32, **kwargs):
"Expected the input tensor to have (.*?) dimensions, but it has (.*?) dimensions.",
)

# target shape is input shape except channels dimension
# target shape is input shape except class dimension
incorrect_batch_target = make((10,), low=0, high=C, dtype=torch.long, requires_grad=False)
yield (
SampleInput(valid_input, incorrect_batch_target),
RuntimeError,
"Expected the target tensor to have the same shape as the input tensor except for the channels dimension \
"Expected the target tensor to have the same shape as the input tensor except for the class dimension \
(.*?), but it has shape (.*?).",
)

Expand Down
94 changes: 47 additions & 47 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3578,9 +3578,9 @@ def cross_entropy(

_cross_entropy_input_checks(a, target, weight, ignore_index, reduction, label_smoothing)

# channels dimension is either the first one if no batch dim present (i.e. a.shape[0]),
# class dimension is either the first one if no batch dim present (i.e. a.shape[0]),
# or right next to it (i.e. a.shape[1]).
channels_dim = 1 if a.ndim >= 2 else 0
class_dim = 1 if a.ndim >= 2 else 0

# NOTE This short-circuit is subject to change and is placed ahead of other input checks to match PyTorch behavior.
# The expected behavior when the target and input have zero elements:
Expand All @@ -3591,7 +3591,7 @@ def cross_entropy(
if a.numel() == 0:
if reduction == "none":
output_shape = list(a.shape)
output_shape.pop(channels_dim)
output_shape.pop(class_dim)
return full(output_shape, 0.0, device=a.device, dtype=a.dtype)
elif reduction == "sum":
return full(result_shape := [], fill_value := 0.0, device=a.device, dtype=a.dtype)
Expand All @@ -3603,7 +3603,7 @@ def cross_entropy(
elif label_smoothing != 0.0:
return _cross_entropy_loss_label_smoothing(a, target, weight, ignore_index, reduction, label_smoothing)
else:
log_softmax_input = log_softmax(a, dim=channels_dim)
log_softmax_input = log_softmax(a, dim=class_dim)
return nll_loss(log_softmax_input, target, weight, ignore_index, reduction)


Expand Down Expand Up @@ -3632,14 +3632,14 @@ def _cross_entropy_input_checks(
lambda: f"Expected label_smoothing to be in [0, 1] range but got {label_smoothing}.",
)

# channels dimension is either the first one if no batch dim present (i.e. a.shape[0]),
# class dimension is either the first one if no batch dim present (i.e. a.shape[0]),
# or right next to it (i.e. a.shape[1]).
channels_dim = 1 if a.ndim >= 2 else 0
num_channels = a.shape[channels_dim]
class_dim = 1 if a.ndim >= 2 else 0
num_class = a.shape[class_dim]

utils.check(
weight is None or (weight.ndim == 1 and weight.shape[0] == num_channels),
lambda: f"Expected a 1D tensor with {num_channels} elements for weight argument, \
weight is None or (weight.ndim == 1 and weight.shape[0] == num_class),
lambda: f"Expected a 1D tensor with {num_class} elements for weight argument, \
but found a tensor with {weight.ndim} dimensions and {weight.shape[0]} elements.",
)

Expand All @@ -3654,13 +3654,13 @@ def _cross_entropy_input_checks(
lambda: f"Expected the input tensor to have {(target.ndim + 1)=} dimensions, but it has {a.ndim} dimensions.",
)

# target should match input in dims which do not correspond to the channels dim, i.e.
# (input.shape[:channels_dim] + input.shape[channels_dim + 1:]) == target.shape <=> True
expected_target_shape = a.shape[:channels_dim] + a.shape[channels_dim + 1 :]
# target should match input in dims which do not correspond to the class dim, i.e.
# (input.shape[:class_dim] + input.shape[class_dim + 1:]) == target.shape <=> True
expected_target_shape = a.shape[:class_dim] + a.shape[class_dim + 1 :]

utils.check(
expected_target_shape == target.shape,
lambda: f"Expected the target tensor to have the same shape as the input tensor except for the channels dimension \
lambda: f"Expected the target tensor to have the same shape as the input tensor except for the class dimension \
{expected_target_shape}, but it has shape {target.shape}.",
)
else:
Expand All @@ -3685,28 +3685,28 @@ def _cross_entropy_loss_probability_target(
reduction: str,
label_smoothing: float,
) -> TensorLike:
# channels dimension is either the first one if no batch dim present (i.e. a.shape[0]),
# class dimension is either the first one if no batch dim present (i.e. a.shape[0]),
# or right next to it (i.e. a.shape[1]).
channels_dim = 1 if a.ndim >= 2 else 0
num_channels = a.shape[channels_dim]
class_dim = 1 if a.ndim >= 2 else 0
num_class = a.shape[class_dim]

if label_smoothing > 0.0:
target = (target * (1 - label_smoothing)) + (label_smoothing / num_channels)
target = (target * (1 - label_smoothing)) + (label_smoothing / num_class)

out = log_softmax(a, dim=channels_dim) * target
out = log_softmax(a, dim=class_dim) * target

if weight is not None:
bcast_weight = reshape(weight, [num_channels] + [1 for _ in range(2, a.ndim)])
bcast_weight = reshape(weight, [num_class] + [1 for _ in range(2, a.ndim)])
out = out * bcast_weight

out = -out

if reduction == "none":
return sum(out, dim=channels_dim)
return sum(out, dim=class_dim)
elif reduction == "sum":
return sum(out)
elif reduction == "mean":
return sum(out) / (a.numel() // num_channels)
return sum(out) / (a.numel() // num_class)


def _cross_entropy_loss_label_smoothing(
Expand All @@ -3718,20 +3718,20 @@ def _cross_entropy_loss_label_smoothing(
reduction: str,
label_smoothing: int,
) -> TensorLike:
# channels dimension is either the first one if no batch dim present (i.e. a.shape[0]),
# class dimension is either the first one if no batch dim present (i.e. a.shape[0]),
# or right next to it (i.e. a.shape[1]).
channels_dim = 1 if a.ndim >= 2 else 0
num_channels = a.shape[channels_dim]
class_dim = 1 if a.ndim >= 2 else 0
num_class = a.shape[class_dim]

log_softmax_value = log_softmax(a, dim=channels_dim)
log_softmax_value = log_softmax(a, dim=class_dim)

if weight is not None:
bcast_weight = reshape(weight, [num_channels] + [1 for _ in range(2, len(a.shape))])
bcast_weight = reshape(weight, [num_class] + [1 for _ in range(2, len(a.shape))])
out = -(log_softmax_value * bcast_weight)
else:
out = -log_softmax_value

smooth_loss = sum(out, dim=channels_dim)
smooth_loss = sum(out, dim=class_dim)

# Make target broadcastable with output, which has same shape as input tensor.
selected_target_mask = target != ignore_index
Expand All @@ -3749,8 +3749,8 @@ def _cross_entropy_loss_label_smoothing(
# Sum together all target weights.
# Make target broadcastable with output, which has same shape as input tensor.
expanded_weight = expand(bcast_weight, a.shape)
bcast_target = unsqueeze(target, channels_dim)
selected_weight = take_along_dim(expanded_weight, bcast_target, channels_dim)
bcast_target = unsqueeze(target, class_dim)
selected_weight = take_along_dim(expanded_weight, bcast_target, class_dim)
selected_weight = where(selected_target_mask, squeeze(selected_weight), 0)
ret = reduced_sum / sum(selected_weight)
else:
Expand All @@ -3760,7 +3760,7 @@ def _cross_entropy_loss_label_smoothing(

nll_loss_value = nll_loss(log_softmax_value, target, weight, ignore_index, reduction)

return (nll_loss_value * (1.0 - label_smoothing)) + (ret * (label_smoothing / num_channels))
return (nll_loss_value * (1.0 - label_smoothing)) + (ret * (label_smoothing / num_class))


# TODO Is this a method?
Expand Down Expand Up @@ -4128,30 +4128,30 @@ def _nll_loss_helper(
lambda: f"Expected the input tensor to have {(target.ndim + 1)=} dimensions, but it has {a.ndim} dimensions.",
)

# channels dimension is either the first one if no batch dim present (i.e. a.shape[0]),
# class dimension is either the first one if no batch dim present (i.e. a.shape[0]),
# or right next to it (i.e. a.shape[1]).
channels_dim = 1 if a.ndim >= 2 else 0
num_channels = a.shape[channels_dim]
# target should match input in dims which do not correspond to the channels dim, i.e.
# (input.shape[:channels_dim] + input.shape[channels_dim + 1:]) == target.shape <=> True
expected_target_shape = a.shape[:channels_dim] + a.shape[channels_dim + 1 :]
class_dim = 1 if a.ndim >= 2 else 0
num_class = a.shape[class_dim]
# target should match input in dims which do not correspond to the class dim, i.e.
# (input.shape[:class_dim] + input.shape[class_dim + 1:]) == target.shape <=> True
expected_target_shape = a.shape[:class_dim] + a.shape[class_dim + 1 :]

utils.check(
expected_target_shape == target.shape,
lambda: f"Expected the target tensor to have the same shape as the input tensor except for the channels dimension \
lambda: f"Expected the target tensor to have the same shape as the input tensor except for the class dimension \
{expected_target_shape}, but it has shape {target.shape}.",
)

utils.check(
weight is None or (weight.ndim == 1 and weight.shape[0] == num_channels),
lambda: f"Expected a 1D tensor with {num_channels} elements for weight argument, \
weight is None or (weight.ndim == 1 and weight.shape[0] == num_class),
lambda: f"Expected a 1D tensor with {num_class} elements for weight argument, \
but found a tensor with {weight.ndim} dimensions and {weight.shape[0]} elements.",
)

# NOTE: [Handling of 'ignore_index' parameter]
# What does it mean to ignore an index?
# The 'ignore_index' parameter specifies a target value that does not contribute to input gradient.
# 'ignore_index' can be outside of the [0, num_channels) range, which can cause out-of-bounds errors when gathering
# 'ignore_index' can be outside of the [0, num_class) range, which can cause out-of-bounds errors when gathering
# values from input tensor.
#
# What does ATen do?
Expand All @@ -4160,12 +4160,12 @@ def _nll_loss_helper(
#
# What do we do?
# We mask the ignore_index entries on the output tensor from take_along_axis because we expect the targets to be
# within [0, num_channels) range.
# within [0, num_class) range.
#
# Why do we like our approach better?
# Mimicking Aten behavior requires masking the target tensor before calling take_along_axis, which would add more
# operations to the fusion. We should follow this approach until we see real examples where ignore_index is
# out-of-bounds of [0, num_channels) range.
# out-of-bounds of [0, num_class) range.
#
# What are the alternative options?
# We can add a `mode` parameter to take_along_axis that controls how to handle out-of-bounds indices.
Expand All @@ -4174,20 +4174,20 @@ def _nll_loss_helper(
out = -a

if weight is not None:
bcast_weight = reshape(weight, [num_channels] + [1 for _ in range(2, a.ndim)])
bcast_weight = reshape(weight, [num_class] + [1 for _ in range(2, a.ndim)])
out = out * bcast_weight

# Make target broadcastable with output, which has same shape as input tensor.
bcast_target = unsqueeze(target, channels_dim)
bcast_target = unsqueeze(target, class_dim)

out = take_along_dim(out, bcast_target, channels_dim)
out = take_along_dim(out, bcast_target, class_dim)
selected_target_mask = bcast_target != ignore_index
out = where(selected_target_mask, out, 0)

# This section handles applying the reduction parameter to the output.
# We return None for the total_weight when reduction is "none" or "sum" since it is unused in the backwards pass.
if reduction == "none":
return squeeze(out, channels_dim), None
return squeeze(out, class_dim), None
elif reduction == "sum":
return sum(out), None
elif reduction == "mean":
Expand All @@ -4197,7 +4197,7 @@ def _nll_loss_helper(
# Mask the ignored target classes.
# Sum together all target weights.
expanded_weight = expand(bcast_weight, a.shape)
selected_weight = take_along_dim(expanded_weight, bcast_target, channels_dim)
selected_weight = take_along_dim(expanded_weight, bcast_target, class_dim)
selected_weight = where(selected_target_mask, selected_weight, 0)
bcast_weight_sum = sum(selected_weight)
return (reduced_sum / bcast_weight_sum), bcast_weight_sum
Expand Down

0 comments on commit 0c6e38c

Please sign in to comment.