Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions edward2/jax/nn/random_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
[3]: Ali Rahimi and Benjamin Recht. Random Features for Large-Scale Kernel
Machines. In _Neural Information Processing Systems_, 2007.
https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf
[4]: Zhiyun Lu, Eugene Ie, Fei Sha. Uncertainty Estimation with Infinitesimal
Jackknife. _arXiv preprint arXiv:2006.07584_, 2020.
https://arxiv.org/abs/2006.07584
"""
import dataclasses
import functools
Expand Down
55 changes: 54 additions & 1 deletion edward2/jax/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@

"""JAX layer and utils."""

from typing import Iterable, Callable
from typing import Callable, Iterable, Optional

from jax import random
import jax.numpy as jnp

Array = jnp.ndarray
DType = type(jnp.float32)
InitializeFn = Callable[[jnp.ndarray, Iterable[int], DType], jnp.ndarray]

Expand Down Expand Up @@ -48,3 +49,55 @@ def initializer(key, shape, dtype=jnp.float32):
x = random.normal(key, shape, dtype) * (-random_sign_init) + 1.0
return x.astype(dtype)
return initializer


def mean_field_logits(logits: Array,
                      covmat: Optional[Array] = None,
                      mean_field_factor: float = 1.,
                      likelihood: str = 'logistic'):
  """Adjusts the model logits so its softmax approximates the posterior mean [4].

  Arguments:
    logits: A float ndarray of shape (batch_size, num_classes).
    covmat: A float ndarray of shape (batch_size, ) containing the predictive
      variance for each example. If None then it is assumed to be a vector of
      1.'s.
    mean_field_factor: The scale factor for mean-field approximation, used to
      adjust the influence of posterior variance in posterior mean
      approximation. If covmat=None then it is used as the scaling parameter
      for temperature scaling. A negative value disables the adjustment and
      returns `logits` unchanged.
    likelihood: name of the likelihood for integration in Gaussian-approximated
      latent posterior. Must be one of ('logistic', 'binary_logistic',
      'poisson'). 'logistic' and 'binary_logistic' share the same scaling rule.

  Returns:
    A float ndarray of uncertainty-adjusted logits, shape
    (batch_size, num_classes).

  Raises:
    ValueError: If likelihood is not one of ('logistic', 'binary_logistic',
      'poisson').
  """
  if likelihood not in ('logistic', 'binary_logistic', 'poisson'):
    # NOTE: the original message had a stray '"' after "Likelihood".
    raise ValueError(
        f'Likelihood must be one of (\'logistic\', \'binary_logistic\', '
        f'\'poisson\'), got {likelihood}.')

  # A negative factor signals "no adjustment"; return the logits untouched.
  if mean_field_factor < 0:
    return logits

  # Defines predictive variance. The scalar 1. broadcasts over any batch.
  variances = 1. if covmat is None else covmat

  # Computes scaling coefficient for mean-field approximation.
  if likelihood == 'poisson':
    logits_scale = jnp.exp(-variances * mean_field_factor / 2.)  # pylint:disable=invalid-unary-operand-type
  else:
    logits_scale = jnp.sqrt(1. + variances * mean_field_factor)

  # Pads logits_scale to compatible dimension so it broadcasts over the
  # trailing (e.g., num_classes) axes of `logits`.
  while logits_scale.ndim < logits.ndim:
    logits_scale = jnp.expand_dims(logits_scale, axis=-1)

  return logits / logits_scale


75 changes: 75 additions & 0 deletions edward2/jax/nn/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# coding=utf-8
# Copyright 2021 The Edward2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for utils."""
from absl.testing import absltest
from absl.testing import parameterized

import edward2.jax as ed

import jax
import jax.numpy as jnp

import numpy as np
import tensorflow as tf


class MeanFieldLogitsTest(parameterized.TestCase, tf.test.TestCase):

  def testMeanFieldLogitsLikelihood(self):
    """Checks the scaling applied under each supported likelihood."""
    num_examples, num_outputs = 10, 12
    variance, factor = 1.5, 2.

    key = jax.random.PRNGKey(0)
    raw_logits = jax.random.normal(key, (num_examples, num_outputs))
    predictive_var = variance * jnp.ones(num_examples)

    adjusted_logistic = ed.nn.utils.mean_field_logits(
        raw_logits, predictive_var, mean_field_factor=factor)
    adjusted_poisson = ed.nn.utils.mean_field_logits(
        raw_logits,
        predictive_var,
        mean_field_factor=factor,
        likelihood='poisson')

    # Logistic: divides by sqrt(1 + variance * factor) = sqrt(4) = 2.
    self.assertAllClose(adjusted_logistic, raw_logits / 2., atol=1e-4)
    # Poisson: divides by exp(-variance * factor / 2) = exp(-1.5).
    self.assertAllClose(adjusted_poisson, raw_logits * np.exp(1.5), atol=1e-4)

  def testMeanFieldLogitsTemperatureScaling(self):
    """Checks mean_field_logits acting as plain temperature scaling."""
    key = jax.random.PRNGKey(0)
    raw_logits = jax.random.normal(key, (10, 12))

    # A negative mean_field_factor must leave the logits untouched.
    untouched = ed.nn.utils.mean_field_logits(
        raw_logits, covmat=None, mean_field_factor=-1)

    # With covmat=None and factor > 0 the call is temperature scaling with
    # temperature sqrt(1. + mean_field_factor) = sqrt(4) = 2.
    halved = ed.nn.utils.mean_field_logits(
        raw_logits, covmat=None, mean_field_factor=3.)

    self.assertAllClose(untouched, raw_logits, atol=1e-4)
    self.assertAllClose(halved, raw_logits / 2., atol=1e-4)


# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  absltest.main()