diff --git a/tests/model/test_base.py b/tests/model/test_base.py index dc2a333..70c3a9a 100644 --- a/tests/model/test_base.py +++ b/tests/model/test_base.py @@ -16,7 +16,9 @@ class TestStochasticTensor(tf.test.TestCase): def test_init(self): - samples = Mock() + static_shape = Mock() + get_shape_func = Mock(return_value=static_shape) + samples = Mock(get_shape=get_shape_func) log_probs = Mock() probs = Mock() sample_func = Mock(return_value=samples) @@ -35,6 +37,8 @@ def test_init(self): self.assertTrue(s_tensor.tensor is samples) self.assertTrue(s_tensor.log_prob(None) is log_probs) self.assertTrue(s_tensor.prob(None) is probs) + self.assertTrue(s_tensor.get_shape() is static_shape) + self.assertTrue(s_tensor.shape is static_shape) obs_int32 = tf.placeholder(tf.int32, None) obs_float32 = tf.placeholder(tf.float32, None) diff --git a/zhusuan/model/base.py b/zhusuan/model/base.py index be7a690..7732294 100644 --- a/zhusuan/model/base.py +++ b/zhusuan/model/base.py @@ -118,7 +118,16 @@ def tensor(self): self._tensor = self.sample(self._n_samples) return self._tensor + @property + def shape(self): + return self.get_shape() + def get_shape(self): + """ + Static :attr:`shape`. + + :return: A `TensorShape` instance. + """ return self.tensor.get_shape() def sample(self, n_samples):