From f783ad71ce4ceda10d5ed2eaf0c573f5186379e8 Mon Sep 17 00:00:00 2001
From: Jeremiah Liu
Date: Tue, 3 Aug 2021 10:13:53 -0700
Subject: [PATCH] Removes unnecessary ViT-GP hyper-parameters.

Following [pull #489](https://github.com/google/edward2/pull/489) to
`edward2.jax.nn.RandomFeatureGaussianProcess`, some of the special
hyper-parameter configs are no longer needed. We therefore remove them to
simplify the model API.

PiperOrigin-RevId: 388484029
---
 edward2/jax/nn/random_feature.py |  3 ++
 edward2/jax/nn/utils.py          | 55 ++++++++++++++++++++++-
 edward2/jax/nn/utils_test.py     | 75 ++++++++++++++++++++++++++++++++
 3 files changed, 132 insertions(+), 1 deletion(-)
 create mode 100644 edward2/jax/nn/utils_test.py

diff --git a/edward2/jax/nn/random_feature.py b/edward2/jax/nn/random_feature.py
index 27b27bb8..92b1e9ee 100644
--- a/edward2/jax/nn/random_feature.py
+++ b/edward2/jax/nn/random_feature.py
@@ -27,6 +27,9 @@
 [3]: Ali Rahimi and Benjamin Recht. Random Features for Large-Scale Kernel
      Machines. In _Neural Information Processing Systems_, 2007.
      https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf
+[4]: Zhiyun Lu, Eugene Ie, Fei Sha. Uncertainty Estimation with Infinitesimal
+     Jackknife. _arXiv preprint arXiv:2006.07584_, 2020.
+     https://arxiv.org/abs/2006.07584
 """
 import dataclasses
 import functools
diff --git a/edward2/jax/nn/utils.py b/edward2/jax/nn/utils.py
index cf7d405d..157e6e59 100644
--- a/edward2/jax/nn/utils.py
+++ b/edward2/jax/nn/utils.py
@@ -15,11 +15,12 @@
 
 """JAX layer and utils."""
 
-from typing import Iterable, Callable
+from typing import Callable, Iterable, Optional
 
 from jax import random
 import jax.numpy as jnp
 
+Array = jnp.ndarray
 DType = type(jnp.float32)
 
 InitializeFn = Callable[[jnp.ndarray, Iterable[int], DType], jnp.ndarray]
@@ -48,3 +49,55 @@ def initializer(key, shape, dtype=jnp.float32):
     x = random.normal(key, shape, dtype) * (-random_sign_init) + 1.0
     return x.astype(dtype)
   return initializer
+
+
+def mean_field_logits(logits: Array,
+                      covmat: Optional[Array] = None,
+                      mean_field_factor: float = 1.,
+                      likelihood: str = 'logistic'):
+  """Adjusts the model logits so their softmax approximates the posterior mean [4].
+
+  Arguments:
+    logits: A float ndarray of shape (batch_size, num_classes).
+    covmat: A float ndarray of shape (batch_size, ). If None, it is assumed
+      to be a vector of 1.'s.
+    mean_field_factor: The scale factor for the mean-field approximation, used
+      to adjust the influence of the posterior variance on the posterior mean
+      approximation. If covmat=None, it is used as the scaling parameter for
+      temperature scaling.
+    likelihood: Name of the likelihood for integration in the
+      Gaussian-approximated latent posterior. Must be one of ('logistic',
+      'binary_logistic', 'poisson').
+
+  Returns:
+    A float ndarray of uncertainty-adjusted logits, shape
+    (batch_size, num_classes).
+
+  Raises:
+    ValueError: If likelihood is not one of ('logistic', 'binary_logistic',
+      'poisson').
+  """
+  if likelihood not in ('logistic', 'binary_logistic', 'poisson'):
+    raise ValueError(
+        f'Likelihood must be one of (\'logistic\', \'binary_logistic\', '
+        f'\'poisson\'), got {likelihood}.')
+
+  if mean_field_factor < 0:
+    return logits
+
+  # Defines the predictive variance.
+  variances = 1. if covmat is None else covmat
+
+  # Computes the scaling coefficient for the mean-field approximation.
+  if likelihood == 'poisson':
+    logits_scale = jnp.exp(-variances * mean_field_factor / 2.)  # pylint:disable=invalid-unary-operand-type
+  else:
+    logits_scale = jnp.sqrt(1. + variances * mean_field_factor)
+
+  # Pads logits_scale to a compatible number of dimensions.
+  while logits_scale.ndim < logits.ndim:
+    logits_scale = jnp.expand_dims(logits_scale, axis=-1)
+
+  return logits / logits_scale
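For context, a minimal usage sketch of the new `mean_field_logits` helper; the input values below are illustrative, not from the patch:

import jax
import jax.numpy as jnp
import edward2.jax as ed

# Illustrative inputs: 4 examples, 3 classes, per-example predictive variances.
logits = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
covmat = jnp.array([0.5, 1.0, 1.5, 2.0])

# Mean-field adjustment under the default logistic likelihood: each row of
# logits is divided by sqrt(1 + mean_field_factor * variance).
adjusted = ed.nn.utils.mean_field_logits(logits, covmat, mean_field_factor=1.)

# With covmat=None, the call reduces to temperature scaling with
# temperature sqrt(1 + mean_field_factor).
tempered = ed.nn.utils.mean_field_logits(
    logits, covmat=None, mean_field_factor=3.)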
diff --git a/edward2/jax/nn/utils_test.py b/edward2/jax/nn/utils_test.py
new file mode 100644
index 00000000..8eebf0d3
--- /dev/null
+++ b/edward2/jax/nn/utils_test.py
@@ -0,0 +1,75 @@
+# coding=utf-8
+# Copyright 2021 The Edward2 Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for utils."""
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import edward2.jax as ed
+
+import jax
+import jax.numpy as jnp
+
+import numpy as np
+import tensorflow as tf
+
+
+class MeanFieldLogitsTest(parameterized.TestCase, tf.test.TestCase):
+
+  def testMeanFieldLogitsLikelihood(self):
+    """Tests whether the scaling is correct under different likelihoods."""
+    batch_size = 10
+    num_classes = 12
+    variance = 1.5
+    mean_field_factor = 2.
+
+    rng_key = jax.random.PRNGKey(0)
+    logits = jax.random.normal(rng_key, (batch_size, num_classes))
+    covmat = jnp.ones(batch_size) * variance
+
+    logits_logistic = ed.nn.utils.mean_field_logits(
+        logits, covmat, mean_field_factor=mean_field_factor)
+    logits_poisson = ed.nn.utils.mean_field_logits(
+        logits,
+        covmat,
+        mean_field_factor=mean_field_factor,
+        likelihood='poisson')
+
+    self.assertAllClose(logits_logistic, logits / 2., atol=1e-4)
+    self.assertAllClose(logits_poisson, logits * np.exp(1.5), atol=1e-4)
+
+  def testMeanFieldLogitsTemperatureScaling(self):
+    """Tests using mean_field_logits as a temperature scaling method."""
+    batch_size = 10
+    num_classes = 12
+
+    rng_key = jax.random.PRNGKey(0)
+    logits = jax.random.normal(rng_key, (batch_size, num_classes))
+
+    # Tests that the logits are unchanged when mean_field_factor < 0.
+    logits_no_change = ed.nn.utils.mean_field_logits(
+        logits, covmat=None, mean_field_factor=-1)
+
+    # Tests that mean_field_logits acts as temperature scaling when
+    # mean_field_factor > 0, with temperature = sqrt(1. + mean_field_factor).
+    logits_scale_by_two = ed.nn.utils.mean_field_logits(
+        logits, covmat=None, mean_field_factor=3.)
+
+    self.assertAllClose(logits_no_change, logits, atol=1e-4)
+    self.assertAllClose(logits_scale_by_two, logits / 2., atol=1e-4)
+
+
+if __name__ == '__main__':
+  absltest.main()
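As a sanity check on the constants asserted in the tests above, plain arithmetic under the scaling rules defined in `mean_field_logits`:

import numpy as np

variance, mean_field_factor = 1.5, 2.

# Logistic likelihood: scale = sqrt(1 + variance * mean_field_factor).
logistic_scale = np.sqrt(1. + variance * mean_field_factor)  # sqrt(4.) == 2.
# Poisson likelihood: scale = exp(-variance * mean_field_factor / 2).
poisson_scale = np.exp(-variance * mean_field_factor / 2.)  # exp(-1.5)

# Dividing the logits by these scales yields logits / 2 and logits * exp(1.5),
# matching the assertions in testMeanFieldLogitsLikelihood.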