-
Notifications
You must be signed in to change notification settings - Fork 23
/
loss.py
180 lines (157 loc) · 6.7 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import os
import itertools
from typing import Any, Optional
import tensorflow as tf
#os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Keras backend fuzz factor. Used below to clip probabilities into
# [_EPSILON, 1 - _EPSILON] before taking their logarithm.
_EPSILON = tf.keras.backend.epsilon()
def sparse_categorical_focal_loss(y_true, y_pred, gamma, *,
                                  class_weight: Optional[Any] = None,
                                  from_logits: bool = False, axis: int = -1
                                  ) -> tf.Tensor:
    r"""Focal loss function for multiclass classification with integer labels.

    For every example the sparse softmax cross-entropy is scaled by the
    modulating factor :math:`(1 - p_t)^\gamma`, where :math:`p_t` is the
    predicted probability of the true class, and optionally by the weight of
    the true class.

    Parameters
    ----------
    y_true : tensor-like
        Integer class labels. Cast to ``int64`` internally.
    y_pred : tensor-like
        Either probabilities or logits, depending on the `from_logits`
        parameter.
    gamma : float or tensor-like of shape (K,)
        The focusing parameter :math:`\gamma`. Higher values of `gamma` make
        easy-to-classify examples contribute less to the loss relative to
        hard-to-classify examples. Must be non-negative. This can be a
        one-dimensional tensor, in which case it specifies a focusing parameter
        for each class.
    class_weight : tensor-like of shape (K,), optional
        Weighting factor for each of the :math:`k` classes. If not specified,
        then all classes are weighted equally.
    from_logits : bool, optional
        Whether `y_pred` contains logits or probabilities. When probabilities
        are given they are clipped to ``[_EPSILON, 1 - _EPSILON]`` and their
        logarithm is used as logits.
    axis : int, optional
        Channel axis in the `y_pred` tensor.

    Returns
    -------
    :class:`tf.Tensor`
        The focal loss for each example.

    Raises
    ------
    ValueError
        If the rank of `y_pred` is statically unknown and `axis` is not -1.
    NotImplementedError
        If the rank of `y_true` is statically unknown.
    """
    # Process focusing parameter
    gamma = tf.convert_to_tensor(gamma, dtype=tf.dtypes.float32)
    scalar_gamma = gamma.shape.rank == 0

    # Process class weight
    if class_weight is not None:
        class_weight = tf.convert_to_tensor(class_weight,
                                            dtype=tf.dtypes.float32)

    # Process prediction tensor
    y_pred = tf.convert_to_tensor(y_pred)
    y_pred_rank = y_pred.shape.rank
    if y_pred_rank is not None:
        axis %= y_pred_rank
        if axis != y_pred_rank - 1:
            # Put channel axis last for sparse_softmax_cross_entropy_with_logits
            perm = list(itertools.chain(range(axis),
                                        range(axis + 1, y_pred_rank), [axis]))
            y_pred = tf.transpose(y_pred, perm=perm)
    elif axis != -1:
        raise ValueError(
            f'Cannot compute sparse categorical focal loss with axis={axis} on '
            'a prediction tensor with statically unknown rank.')
    # Dynamic shape, captured after the transpose so the channel axis is last.
    y_pred_shape = tf.shape(y_pred)

    # Process ground truth tensor
    y_true = tf.dtypes.cast(y_true, dtype=tf.dtypes.int64)
    y_true_rank = y_true.shape.rank
    if y_true_rank is None:
        raise NotImplementedError('Sparse categorical focal loss not supported '
                                  'for target/label tensors of unknown rank')

    # Labels are expected to have exactly one axis fewer than predictions;
    # otherwise flatten both and restore the prediction batch shape at the end.
    reshape_needed = (y_pred_rank is not None and
                      y_pred_rank != y_true_rank + 1)
    if reshape_needed:
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1, y_pred_shape[-1]])

    if from_logits:
        logits = y_pred
        probs = tf.nn.softmax(y_pred, axis=-1)
    else:
        probs = y_pred
        logits = tf.math.log(tf.clip_by_value(y_pred, _EPSILON, 1 - _EPSILON))

    xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=y_true,
        logits=logits,
    )

    # Select the probability of the true class. Every axis of y_true is a
    # batch axis of probs, hence batch_dims equals the (possibly reshaped)
    # label rank.
    y_true_rank = y_true.shape.rank
    probs = tf.gather(probs, y_true, axis=-1, batch_dims=y_true_rank)

    # gamma and class_weight are per-class tensors that share no batch axes
    # with y_true, so a plain lookup by label (no batch_dims) yields one value
    # per example, shaped like y_true.
    if not scalar_gamma:
        gamma = tf.gather(gamma, y_true, axis=0)
    focal_modulation = (1 - probs) ** gamma
    loss = focal_modulation * xent_loss

    if class_weight is not None:
        class_weight = tf.gather(class_weight, y_true, axis=0)
        loss *= class_weight

    if reshape_needed:
        loss = tf.reshape(loss, y_pred_shape[:-1])

    return loss
@tf.keras.utils.register_keras_serializable()
class SparseCategoricalFocalLoss(tf.keras.losses.Loss):
    r"""Keras loss object for the sparse categorical focal loss.

    Stores the focal-loss hyperparameters and forwards them to
    ``sparse_categorical_focal_loss`` whenever the loss is evaluated.

    Parameters
    ----------
    gamma : float or tensor-like of shape (K,)
        The focusing parameter :math:`\gamma`. Higher values of `gamma` make
        easy-to-classify examples contribute less to the loss relative to
        hard-to-classify examples. Must be non-negative. This can be a
        one-dimensional tensor, in which case it specifies a focusing parameter
        for each class.
    class_weight : tensor-like of shape (K,), optional
        Weighting factor for each of the :math:`k` classes. If not specified,
        then all classes are weighted equally.
    from_logits : bool, optional
        Whether model prediction will be logits or probabilities.
    **kwargs : keyword arguments
        Other keyword arguments for :class:`tf.keras.losses.Loss` (e.g., `name`
        or `reduction`).
    """

    def __init__(self, gamma, class_weight: Optional[Any] = None,
                 from_logits: bool = False, **kwargs):
        super().__init__(**kwargs)
        # Hyperparameters are kept exactly as given; conversion to tensors
        # happens inside the functional loss on every call.
        self.from_logits = from_logits
        self.class_weight = class_weight
        self.gamma = gamma

    def get_config(self):
        """Return the configuration of this loss.

        The base-class configuration is extended with the focal-loss
        hyperparameters so that an equivalent loss object can be re-created
        from the returned dictionary.

        Returns
        -------
        dict
            This loss's config.
        """
        base_config = super().get_config()
        base_config['gamma'] = self.gamma
        base_config['class_weight'] = self.class_weight
        base_config['from_logits'] = self.from_logits
        return base_config

    def call(self, y_true, y_pred):
        """Compute the per-example focal loss.

        Delegates to ``sparse_categorical_focal_loss`` using the
        hyperparameters stored on this object.

        Parameters
        ----------
        y_true : tensor-like, shape (N,)
            Integer class labels.
        y_pred : tensor-like, shape (N, K)
            Either probabilities or logits, depending on the `from_logits`
            parameter.

        Returns
        -------
        :class:`tf.Tensor`
            The per-example focal loss. Reduction to a scalar is handled by
            the ``__call__`` method inherited from
            :class:`tf.keras.losses.Loss`.
        """
        return sparse_categorical_focal_loss(
            y_true=y_true,
            y_pred=y_pred,
            gamma=self.gamma,
            class_weight=self.class_weight,
            from_logits=self.from_logits,
        )