
Phased LSTM implementation in TensorFlow

Welcome to the TensorFlow implementation of the recently introduced Phased LSTM by Neil et al. (NIPS 2016).

The original implementation by Daniel Neil (in Theano) can be found here.


Here are the results of the frequency discrimination task with high-resolution sampling (described in the paper): PLSTM (dark blue) converges much faster than GRU (light blue).



I implemented the PLSTM in a plug-and-play fashion: if you want to use it in one of your models, you can simply switch from LSTMCell/GRUCell to PhasedLSTMCell.

The core of the PLSTM is:

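    import tensorflow as tf
    from tensorflow.python.ops import init_ops
    from tensorflow.python.ops import variable_scope as vs

    # Excerpt from PhasedLSTMCell (TensorFlow 0.12). random_exp_initializer is a
    # helper defined in this repo; times holds one timestamp per example, shape [batch, 1].

    # learnable per-unit period tau, open ratio r_on and phase shift s
    tau = vs.get_variable(
        "T", shape=[self._num_units],
        initializer=random_exp_initializer(0, self.tau_init), dtype=dtype)

    r_on = vs.get_variable(
        "R", shape=[self._num_units],
        initializer=init_ops.constant_initializer(self.r_on_init), dtype=dtype)

    s = vs.get_variable(
        "S", shape=[self._num_units],
        initializer=init_ops.random_uniform_initializer(0., tau.initialized_value()), dtype=dtype)
    # for backward compatibility (v < 0.12.0) use the following line instead of the above
    # initializer = init_ops.random_uniform_initializer(0., tau), dtype = dtype)

    # add a batch dimension so the parameters broadcast against [batch, num_units]
    tau_broadcast = tf.expand_dims(tau, dim=0)
    r_on_broadcast = tf.expand_dims(r_on, dim=0)
    s_broadcast = tf.expand_dims(s, dim=0)

    # use absolute values so periods and open ratios stay positive
    r_on_broadcast = tf.abs(r_on_broadcast)
    tau_broadcast = tf.abs(tau_broadcast)
    times = tf.tile(times, [1, self._num_units])  # one copy of the timestamp per unit

    # calculate kronos gate
    # phase phi in [0, 1): where each unit currently is inside its own period
    phi = tf.div(tf.mod(tf.mod(times - s_broadcast, tau_broadcast) + tau_broadcast, tau_broadcast),
                 tau_broadcast)
    is_up = tf.less(phi, (r_on_broadcast * 0.5))
    is_down = tf.logical_and(tf.less(phi, r_on_broadcast), tf.logical_not(is_up))

    # openness k: rises 0 -> 1, falls 1 -> 0, then leaks with slope alpha while closed
    # (TensorFlow >= 1.0 renamed tf.select to tf.where)
    k = tf.select(is_up, phi / (r_on_broadcast * 0.5),
                  tf.select(is_down, 2. - 2. * (phi / r_on_broadcast), self.alpha * phi))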
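If you want to check the gate math without TensorFlow, here is a minimal pure-Python sketch of the same piecewise gate for a single unit with scalar inputs. `kronos_gate` is a hypothetical helper, not part of this repo, and the values in the usage line are only illustrative. Python's `%` already returns a non-negative result for a positive `tau`, so a single modulo is enough here.

```python
def kronos_gate(t, tau, s, r_on, alpha):
    """Openness k of the time gate for one unit at time t (scalar sketch)."""
    # phase in [0, 1): position of t inside the unit's current period
    phi = ((t - s) % tau) / tau
    if phi < 0.5 * r_on:
        # first half of the open phase: k rises linearly from 0 to 1
        return phi / (0.5 * r_on)
    if phi < r_on:
        # second half of the open phase: k falls linearly from 1 to 0
        return 2.0 - 2.0 * phi / r_on
    # closed phase: only a small leak with slope alpha
    return alpha * phi
```

For example, with `tau=10.0`, `s=0.0`, `r_on=0.2` and `alpha=0.001`, the gate is fully open (k = 1) at t = 1.0, half open at t = 0.5 and t = 1.5, and almost closed (k = 0.0005) at t = 5.0.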

Then the kronos gate is applied to the cell state c and to the output m simply by:

    # blend the new LSTM values with the previous ones according to k
    c = k * c + (1. - k) * c_prev
    m = k * m + (1. - k) * m_prev
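Element-wise this is a convex combination: where the gate is fully open (k = 1) a unit takes the ordinary LSTM update, and where it is closed (k close to 0) the unit essentially keeps its previous value. A tiny pure-Python sketch over per-unit lists (`gated_update` is a hypothetical name, used here only for illustration):

```python
def gated_update(k, new, prev):
    """Per-unit blend k[i] * new[i] + (1 - k[i]) * prev[i]."""
    return [ki * n + (1.0 - ki) * p for ki, n, p in zip(k, new, prev)]


# unit 0 open, unit 1 closed, unit 2 half open
print(gated_update([1.0, 0.0, 0.5], [4.0, 4.0, 4.0], [2.0, 2.0, 2.0]))  # [4.0, 2.0, 3.0]
```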

PhasedLSTMCell has the same set of parameters as LSTMCell, plus the following ones (their defaults are those indicated in the paper):

  • The slope in the off period of the gate (self.alpha in the code above)
  • The initial value of r_on (self.r_on_init)
  • The parameter for the initial sampling of tau (self.tau_init)

tau is sampled as exp(uniform(0, tau_init)), i.e. the exponential of a uniform sample.
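Because the exponent is uniform, the periods are log-uniformly spread between 1 and exp(tau_init), so the units start out covering several time scales. A stdlib sketch of this sampling; `sample_tau` is a hypothetical name (the repo does this through `random_exp_initializer`):

```python
import math
import random


def sample_tau(num_units, tau_init, rng=random):
    """Draw one initial period per unit as exp(U(0, tau_init))."""
    return [math.exp(rng.uniform(0.0, tau_init)) for _ in range(num_units)]


taus = sample_tau(8, 3.0, random.Random(0))  # every value lies in [1, e^3]
```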

Notes for backward compatibility (v < 0.12.0)

The current implementation uses TensorFlow 0.12.0. If you don't want to update TensorFlow (by the way, you should :)), I left some commented-out lines in the code for backward compatibility. In the cell, for the initializer of s:

    initializer=init_ops.random_uniform_initializer(0., tau.initialized_value()), dtype=dtype)
    # for backward compatibility (v < 0.12.0) use the following line instead of the above
    # initializer = init_ops.random_uniform_initializer(0., tau), dtype = dtype)

In the unit test:

    # for backward compatibility (v < 0.12.0) use the following line instead of the above
    # initialize_all_variables(sess)

Remember also to change the summary calls (new name -> old name):

    tf.summary.scalar -> tf.scalar_summary 
    tf.summary.histogram -> tf.histogram_summary 
    tf.summary.FileWriter -> tf.train.SummaryWriter 
    tf.summary.merge -> tf.merge_summary 
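If you have several scripts to convert, the renames above can be applied mechanically to the source text. This is a hypothetical helper sketch (not part of the repo) that only covers the four names listed here; the regex word boundary keeps, for example, `tf.summary.merge_all` from being mangled by the `tf.summary.merge` rule.

```python
import re

# new (>= 0.12) name -> old (< 0.12) name, as listed above
SUMMARY_RENAMES = {
    "tf.summary.scalar": "tf.scalar_summary",
    "tf.summary.histogram": "tf.histogram_summary",
    "tf.summary.FileWriter": "tf.train.SummaryWriter",
    "tf.summary.merge": "tf.merge_summary",
}


def downgrade_summary_calls(source):
    """Rewrite the summary calls in a script's source text for TensorFlow < 0.12."""
    for new, old in SUMMARY_RENAMES.items():
        # \b stops a match when the name continues, e.g. tf.summary.merge_all
        source = re.sub(re.escape(new) + r"\b", old, source)
    return source
```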

Paper's Task

I implemented the first task described in the paper, frequency discrimination. The network is presented with sine waves and has to discriminate between waves within a target frequency range (e.g. 5-6 Hz) and waves outside of this range. Furthermore, there are three different ways to sample these sine waves:

  • Low resolution (1 ms)
  • High resolution (0.1 ms)
  • Asynchronously

All three sampling modes are implemented and can be selected with the script's flags.
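To make the three modes concrete, here is a stdlib sketch of how one labelled example could be generated. It is not the repo's data generator: the function names, the random phase, the sequence duration and the density of asynchronous samples (about one per ms, uniformly scattered) are all assumptions made for illustration.

```python
import math
import random


def sample_times(duration_ms, mode, rng=random):
    """Sampling instants (in ms) for one sequence of length duration_ms."""
    if mode == "low":
        step = 1.0  # low resolution: one sample every 1 ms
    elif mode == "high":
        step = 0.1  # high resolution: one sample every 0.1 ms
    elif mode == "async":
        # asynchronous: irregular instants scattered uniformly over the sequence,
        # about one per ms on average (both choices are assumptions)
        n = int(duration_ms)
        return sorted(rng.uniform(0.0, duration_ms) for _ in range(n))
    else:
        raise ValueError(f"unknown sampling mode: {mode!r}")
    n = int(round(duration_ms / step))
    return [i * step for i in range(n)]


def make_example(freq_hz, duration_ms, mode, target_hz=(5.0, 6.0), rng=random):
    """Return (times, values, label); label is 1 if freq_hz is inside target_hz."""
    times = sample_times(duration_ms, mode, rng)
    phase = rng.uniform(0.0, 2.0 * math.pi)  # random phase offset (assumption)
    values = [math.sin(2.0 * math.pi * freq_hz * t / 1000.0 + phase) for t in times]
    label = int(target_hz[0] <= freq_hz <= target_hz[1])
    return times, values, label
```

For example, `make_example(5.5, 100.0, "high", rng=random.Random(0))` yields 1000 samples labelled 1. The timestamps play the role of `times` in the kronos gate, which is why the PLSTM can consume the asynchronous mode directly without resampling it onto a grid.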

Update 17-02-2017

Transition to TensorFlow v1.0: the folder now has two different scripts, one for v1.0 and one for older versions. To use v1.0, refer to PLSTM_v1.

I also updated:

    outputs = multiPLSTM(cells, inputs, lens, n_input, initial_states)

The function now takes a list of previously created cells, so you can parameterize every cell separately.

The function also supports in-place copying of the initial states.


Let me know if you encounter any problems:
