Skip to content

Commit

Permalink
Add Soft Weighted Kappa Loss (tensorflow#762)
Browse files Browse the repository at this point in the history
* add weighted kappa loss

* add unit tests

* change some docs

* change python files format

* shorten some lines

* rename and update README and BUILD

* resolve conversations

* resolve converstions

* remove escape

* reformat tensorflow_addons/losses/kappa_loss* with black

* reformat code

* reformat code

* reformat code with black

* Update tensorflow_addons/losses/kappa_loss.py

Co-Authored-By: Gabriel de Marmiesse <gabrieldemarmiesse@gmail.com>

* [KappaLoss] change according to review

* Update tensorflow_addons/losses/kappa_loss.py

Co-Authored-By: Gabriel de Marmiesse <gabrieldemarmiesse@gmail.com>

* Update tensorflow_addons/losses/kappa_loss.py

Co-Authored-By: Gabriel de Marmiesse <gabrieldemarmiesse@gmail.com>

* [KappaLoss] change accroding to code review

* [KappaLoss] change code format

* [SoftKappaLoss] mv kappa_loss_test.py to losses/tests

* Update .github/CODEOWNERS

Co-Authored-By: Gabriel de Marmiesse <gabrieldemarmiesse@gmail.com>

* [SoftKappaLoss] refine codes according to code review

* [SoftKappaLoss] reformat codes

* [SoftKappaLoss] fix np_deep not defined

* [SoftKappaLoss] fix tests problem

* [SoftKappaLoss] unnecessary change to tigger CI

* Default value for the seed is not needed.

Co-authored-by: gabrieldemarmiesse <gabrieldemarmiesse@gmail.com>
  • Loading branch information
wenmin-wu and gabrieldemarmiesse authored Apr 13, 2020
1 parent 1fdb5b1 commit 6810fb3
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@
/tensorflow_addons/losses/tests/sparsemax_loss_test.py @andreasmadsen
/tensorflow_addons/losses/triplet.py @lc0
/tensorflow_addons/losses/tests/triplet_test.py @lc0
/tensorflow_addons/losses/kappa_loss.py @wenmin-wu
/tensorflow_addons/losses/tests/kappa_loss_test.py @wenmin-wu

/tensorflow_addons/metrics/cohens_kappa.py @aakashkumarnain
/tensorflow_addons/metrics/tests/cohens_kappa_test.py @aakashkumarnain
Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@
npairs_multilabel_loss,
NpairsMultilabelLoss,
)
from tensorflow_addons.losses.kappa_loss import WeightedKappaLoss
132 changes: 132 additions & 0 deletions tensorflow_addons/losses/kappa_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements Weighted kappa loss."""

import tensorflow as tf
from tensorflow_addons.utils.types import Number
from typeguard import typechecked
from typing import Optional


@tf.keras.utils.register_keras_serializable(package="Addons")
class WeightedKappaLoss(tf.keras.losses.Loss):
"""Implements the Weighted Kappa loss function.
Weighted Kappa loss was introduced in the
[Weighted kappa loss function for multi-class classification
of ordinal data in deep learning]
(https://www.sciencedirect.com/science/article/abs/pii/S0167865517301666).
Weighted Kappa is widely used in Ordinal Classification Problems.
The loss value lies in [-inf, log 2], where log 2
means the random prediction.
Usage:
```python
kappa_loss = WeightedKappaLoss(num_classes=4)
y_true = tf.constant([[0, 0, 1, 0], [0, 1, 0, 0],
[1, 0, 0, 0], [0, 0, 0, 1]])
y_pred = tf.constant([[0.1, 0.2, 0.6, 0.1], [0.1, 0.5, 0.3, 0.1],
[0.8, 0.05, 0.05, 0.1], [0.01, 0.09, 0.1, 0.8]])
loss = kappa_loss(y_true, y_pred)
print('Loss: ', loss.numpy()) # Loss: -1.1611923
```
Usage with `tf.keras` API:
```python
# outputs should be softmax results
# if you want to weight the samples, just multiply the outputs
# by the sample weight.
model = tf.keras.Model(inputs, outputs)
model.compile('sgd', loss=tfa.losses.WeightedKappa(num_classes=4))
```
"""

@typechecked
def __init__(
self,
num_classes: int,
weightage: Optional[str] = "quadratic",
name: Optional[str] = "cohen_kappa_loss",
epsilon: Optional[Number] = 1e-6,
dtype: Optional[tf.DType] = tf.float32,
reduction: str = tf.keras.losses.Reduction.NONE,
):
"""Creates a `WeightedKappa` instance.
Args:
num_classes: Number of unique classes in your dataset.
weightage: (Optional) Weighting to be considered for calculating
kappa statistics. A valid value is one of
['linear', 'quadratic']. Defaults to `quadratic` since it's
mostly used.
name: (Optional) String name of the metric instance.
epsilon: (Optional) increment to avoid log zero,
so the loss will be log(1 - k + epsilon), where k belongs to
[-1, 1], usually you can use the default value which is 1e-6.
dtype: (Optional) Data type of the metric result.
Defaults to `tf.float32`.
Raises:
ValueError: If the value passed for `weightage` is invalid
i.e. not any one of ['linear', 'quadratic']
"""

super().__init__(name=name, reduction=reduction)

if weightage not in ("linear", "quadratic"):
raise ValueError("Unknown kappa weighting type.")

self.weightage = weightage
self.num_classes = num_classes
self.epsilon = epsilon
self.dtype = dtype
label_vec = tf.range(num_classes, dtype=dtype)
self.row_label_vec = tf.reshape(label_vec, [1, num_classes])
self.col_label_vec = tf.reshape(label_vec, [num_classes, 1])
col_mat = tf.tile(self.col_label_vec, [1, num_classes])
row_mat = tf.tile(self.row_label_vec, [num_classes, 1])
if weightage == "linear":
self.weight_mat = tf.abs(col_mat - row_mat)
else:
self.weight_mat = (col_mat - row_mat) ** 2

def call(self, y_true, y_pred):
y_true = tf.cast(y_true, dtype=self.dtype)
batch_size = tf.shape(y_true)[0]
cat_labels = tf.matmul(y_true, self.col_label_vec)
cat_label_mat = tf.tile(cat_labels, [1, self.num_classes])
row_label_mat = tf.tile(self.row_label_vec, [batch_size, 1])
if self.weightage == "linear":
weight = tf.abs(cat_label_mat - row_label_mat)
else:
weight = (cat_label_mat - row_label_mat) ** 2
numerator = tf.reduce_sum(weight * y_pred)
label_dist = tf.reduce_sum(y_true, axis=0, keepdims=True)
pred_dist = tf.reduce_sum(y_pred, axis=0, keepdims=True)
w_pred_dist = tf.matmul(self.weight_mat, pred_dist, transpose_b=True)
denominator = tf.reduce_sum(tf.matmul(label_dist, w_pred_dist))
denominator /= tf.cast(batch_size, dtype=self.dtype)
loss = tf.math.divide_no_nan(numerator, denominator)
return tf.math.log(loss + self.epsilon)

def get_config(self):
config = {
"num_classes": self.num_classes,
"weightage": self.weightage,
"epsilon": self.epsilon,
"dtype": self.dtype,
}
base_config = super().get_config()
return {**base_config, **config}
92 changes: 92 additions & 0 deletions tensorflow_addons/losses/tests/kappa_loss_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Weighted Kappa Loss."""

import pytest
import numpy as np
import tensorflow as tf
from tensorflow_addons.losses.kappa_loss import WeightedKappaLoss


def weighted_kappa_loss_np(y_true, y_pred, weightage="quadratic", eps=1e-6):
num_samples, num_classes = y_true.shape
cat_labels = y_true.argmax(axis=1).reshape((-1, 1))
label_mat = np.tile(cat_labels, (1, num_classes))
row_label_vec = np.arange(num_classes).reshape((1, num_classes))
label_mat_ = np.tile(row_label_vec, (num_samples, 1))
if weightage == "linear":
weight = np.abs(label_mat - label_mat_)
else:
weight = (label_mat - label_mat_) ** 2
numerator = (y_pred * weight).sum()
label_dist = y_true.sum(axis=0, keepdims=True)
pred_dist = y_pred.sum(axis=0, keepdims=True)

col_label_vec = row_label_vec.T
row_mat = np.tile(row_label_vec, (num_classes, 1))
col_mat = np.tile(col_label_vec, (1, num_classes))
if weightage == "quadratic":
weight_ = (col_mat - row_mat) ** 2
else:
weight_ = np.abs(col_mat - row_mat)
weighted_pred_dist = np.matmul(weight_, pred_dist.T)
denominator = np.matmul(label_dist, weighted_pred_dist).sum()
denominator /= num_samples
return np.log(np.nan_to_num(numerator / denominator) + eps)


def gen_labels_and_preds(num_samples, num_classes, seed):
np.random.seed(seed)
rands = np.random.uniform(size=(num_samples, num_classes))
cat_labels = rands.argmax(axis=1)
y_true = np.eye(num_classes, dtype="int")[cat_labels]
y_pred = np.random.uniform(size=(num_samples, num_classes))
y_pred /= y_pred.sum(axis=1, keepdims=True)
return y_true, y_pred


@pytest.mark.parametrize("np_seed", [0, 1, 2, 3])
def test_linear_weighted_kappa_loss(np_seed):
y_true, y_pred = gen_labels_and_preds(50, 4, np_seed)
kappa_loss = WeightedKappaLoss(num_classes=4, weightage="linear")
y_pred = y_pred.astype(kappa_loss.dtype.as_numpy_dtype)
loss = kappa_loss(y_true, y_pred)
loss_np = weighted_kappa_loss_np(y_true, y_pred, weightage="linear")
np.testing.assert_allclose(loss, loss_np, rtol=1e-5, atol=1e-5)


@pytest.mark.parametrize("np_seed", [0, 1, 2, 3])
def test_quadratic_weighted_kappa_loss(np_seed):
y_true, y_pred = gen_labels_and_preds(100, 3, np_seed)
kappa_loss = WeightedKappaLoss(num_classes=3)
y_pred = y_pred.astype(kappa_loss.dtype.as_numpy_dtype)
loss = kappa_loss(y_true, y_pred)
loss_np = weighted_kappa_loss_np(y_true, y_pred)
np.testing.assert_allclose(loss, loss_np, rtol=1e-5, atol=1e-5)


def test_config():
kappa_loss = WeightedKappaLoss(
num_classes=4, weightage="linear", name="kappa_loss", epsilon=0.001,
)
assert kappa_loss.num_classes == 4
assert kappa_loss.weightage == "linear"
assert kappa_loss.name == "kappa_loss"
np.testing.assert_allclose(kappa_loss.epsilon, 0.001, 1e-6)


def test_serialization():
loss = WeightedKappaLoss(num_classes=3)
tf.keras.losses.deserialize(tf.keras.losses.serialize(loss))

0 comments on commit 6810fb3

Please sign in to comment.