From 2bd05b065ce1334abe853634041276c768c78d95 Mon Sep 17 00:00:00 2001 From: Shuyu Cheng Date: Wed, 10 Oct 2018 15:26:02 +0800 Subject: [PATCH] Add shape attribute to StochasticTensor to fix the error when running vae_conv --- tests/model/test_base.py | 6 +++++- zhusuan/model/base.py | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/model/test_base.py b/tests/model/test_base.py index dc2a333..70c3a9a 100644 --- a/tests/model/test_base.py +++ b/tests/model/test_base.py @@ -16,7 +16,9 @@ class TestStochasticTensor(tf.test.TestCase): def test_init(self): - samples = Mock() + static_shape = Mock() + get_shape_func = Mock(return_value=static_shape) + samples = Mock(get_shape=get_shape_func) log_probs = Mock() probs = Mock() sample_func = Mock(return_value=samples) @@ -35,6 +37,8 @@ def test_init(self): self.assertTrue(s_tensor.tensor is samples) self.assertTrue(s_tensor.log_prob(None) is log_probs) self.assertTrue(s_tensor.prob(None) is probs) + self.assertTrue(s_tensor.get_shape() is static_shape) + self.assertTrue(s_tensor.shape is static_shape) obs_int32 = tf.placeholder(tf.int32, None) obs_float32 = tf.placeholder(tf.float32, None) diff --git a/zhusuan/model/base.py b/zhusuan/model/base.py index be7a690..7732294 100644 --- a/zhusuan/model/base.py +++ b/zhusuan/model/base.py @@ -118,7 +118,16 @@ def tensor(self): self._tensor = self.sample(self._n_samples) return self._tensor + @property + def shape(self): + return self.get_shape() + def get_shape(self): + """ + Static :attr:`shape`. + + :return: A `TensorShape` instance. + """ return self.tensor.get_shape() def sample(self, n_samples):