
Recurrent Neural Network (RNN) Intuition

What are Recurrent Neural Networks Used for?

RNNs are used in problems where knowing the previous context is important for determining the next output. For example, if a text predictor wants to predict the next word in the phrase "Regarding food, I love _", the previous words "Regarding food, I love" are critical context for predicting what comes next.

The technical term for this type of problem is Time Series Analysis.

Natural Language Processing is a field that heavily uses RNNs since they naturally model some key characteristics of natural language that other types of networks struggle with:

  • In natural language, word order is critical in giving the text meaning, and there is no easy way to model order dependencies in ANNs or CNNs.
  • Natural language has "long-distance effects", which are inherently a time series problem.
    • e.g. "I, the president of the United States, AM going to declare". In this example, "AM" depends on the word "I", which appears much earlier in the sentence.
  • Inputs in many NLP problems need to be of varying length because texts are of varying length, while ANNs and CNNs only take fixed-size inputs.

Structure of a RNN

RNNs are typically drawn in two forms: rolled and unrolled. The figure below explains the structure using the unrolled representation.

Simplified RNN Structure

The key characteristic of RNNs is that neurons are connected to themselves through time. This means that a neuron at time t+1 uses two sources of input:

  1. The input from the input sequence.
  2. The output from time t.

Don't be fooled by the simplified flat representation above: RNNs are multidimensional, since the state at each time step is a vector rather than a single value. A more precise representation of their structure is shown below.

Multidimensional RNN Structure
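As a rough illustration of the recurrence described above, here is a minimal sketch of a single "vanilla" RNN step in plain numpy (the weight names and dimensions are arbitrary, illustrative choices, not a library API):

```python
import numpy as np

# One "vanilla" RNN step: the new hidden state mixes the current input with the
# previous hidden state -- the "neuron connected to itself through time" idea.
def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    return np.tanh(x_t @ W_xh + h_prev @ W_hh + b_h)

# Toy dimensions: 4-dimensional inputs, 3-dimensional hidden state.
rng = np.random.default_rng(0)
W_xh = rng.normal(size=(4, 3))
W_hh = rng.normal(size=(3, 3))
b_h = np.zeros(3)

h = np.zeros(3)                       # initial hidden state
for x_t in rng.normal(size=(5, 4)):   # a sequence of 5 input vectors
    h = rnn_step(x_t, h, W_xh, W_hh, b_h)
```

Note how h carries information forward: each step's output becomes part of the next step's input.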

Example Applications of RNNs

RNN architecture varies depending on the problem: some RNNs take a single input, others take a sequence of inputs, and the same applies to the outputs.

The diagram below shows multiple variants of RNN architectures being applied to solve different types of problems.

Sample RNN applications

The Difficulties of Training Deep Networks: The Vanishing and Exploding Gradient Problems

Neural networks get increasingly harder to train as we add more layers. The gradient, the main piece of information used to update the weights, becomes unstable as the number of layers increases due to compounded derivative calculations.

  • Remember that the gradient in layer l is calculated as the product of the derivatives for all layers from l+1 to the last layer. For example, in a 100-layer-deep network, gradient_l1 ~= some_derivative_l2 * some_derivative_l3 * ... * some_derivative_l100.
  • Vanishing Gradient: when many of the some_derivative_lX factors are < 1, their product will be very close to 0. This results in gradient_l1 being close to 0, so the weights in l1 barely get updated during back propagation and get stuck. This means that early layers in the network are hard to train and require many epochs and lots of training data.
  • Exploding Gradient: when many of the some_derivative_lX factors are > 1, their product will be very big. This results in gradient_l1 being very big, causing violent and unstable updates of the weights in l1. These big "jumps" in the weights make training unstable, and the network might never find the optimum because of these swings. The numeric sketch after this list illustrates both effects.
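As a purely numeric sketch (not tied to any particular network), multiplying 100 per-layer derivatives that are slightly below or slightly above 1 shows how quickly the product collapses or blows up:

```python
vanishing, exploding = 1.0, 1.0
for _ in range(100):
    vanishing *= 0.9   # derivatives slightly below 1
    exploding *= 1.1   # derivatives slightly above 1

print(vanishing)  # ~2.7e-05: early layers barely receive any gradient signal
print(exploding)  # ~1.4e+04: early layers receive violently large updates
```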

If this doesn't make any sense, I suggest DeepLizard's YouTube video on this topic.

This problem is not exclusive to RNNs. However, RNNs are very prone to it as their recursive nature makes them "arbitrarily deep".

The vanishing gradient problem in an RNN

Solutions to these problems

In general, both problems are governed by the same mechanism and can be partially solved by doing smart weight initialization that ensures that the gradients don't get too small or too big.

A very popular way of doing this "smart" initialization is called Xavier Initialization (aka Glorot Initialization). Xavier Initialization is a per-layer initialization technique that initializes the weights randomly, centred at 0, with a spread that changes per layer depending on the layer's number of inputs and outputs. To know more details about this initialization, see DeepLizard's video on it.

Keras supports Xavier (Glorot) initialization as shown next. Moreover, if we don't specify a kernel_initializer, Keras uses glorot_uniform by default.

Dense(32, activation='relu', kernel_initializer='glorot_uniform')
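For intuition, here is a hedged numpy sketch of roughly what glorot_uniform computes for a single weight matrix (the function name and seed are illustrative; the sampling limit sqrt(6 / (fan_in + fan_out)) is the standard Glorot uniform formula):

```python
import numpy as np

# Glorot (Xavier) uniform: sample weights uniformly in [-limit, limit],
# where the limit depends on the layer's number of inputs and outputs.
def glorot_uniform(fan_in, fan_out, seed=0):
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    rng = np.random.default_rng(seed)
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))

W = glorot_uniform(64, 32)  # weights for a 64 -> 32 dense layer
```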

There are other solutions that are specific to the type of problem we face: exploding or vanishing gradients. The details of most of these solutions are out of the scope of these notes.

Some specific solutions for the Exploding Gradient Problem

  • Truncated Back Propagation: stop back propagation after a fixed number of steps (layers in the unrolled network).
  • Penalties: add a penalty (regularization) term that discourages the gradients from growing too large.
  • Gradient Clipping: set a fixed maximum value for the gradient and back propagate that value whenever the calculated gradient exceeds it (see the Keras sketch after this list).
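As an example of the last point, Keras optimizers accept clipvalue and clipnorm arguments. A minimal sketch (the thresholds and learning rate are arbitrary placeholders, and `model` is assumed to already exist):

```python
from tensorflow.keras.optimizers import SGD

# Clip each gradient element to [-1, 1] before the weight update.
optimizer = SGD(learning_rate=0.01, clipvalue=1.0)
# Alternatively, rescale whole gradients whose L2 norm exceeds 1:
# optimizer = SGD(learning_rate=0.01, clipnorm=1.0)

# model.compile(optimizer=optimizer, loss='mse')  # assumes an already-built `model`
```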

Some specific solutions for the Vanishing Gradient Problem

  • Echo State Networks: A network architecture that helps with the problem (out of scope of this summary).
  • Long Short-Term Memory (LSTM) Networks: Another network architecture that tackles this problem. This architecture is very popular for RNNs and we will discuss it in detail in the next section.

Long Short-Term Memory (LSTM) Networks

This summary is heavily based on the following article. I strongly recommend reading it: Understanding LSTM Networks by Christopher Olah

LSTM networks are a particular architecture of the more general RNN category. They are of particular importance because almost all exciting results from RNNs have been achieved using this architecture.

Motivation

An RNN's main feature is taking previous inputs and predictions into account when performing the current time-step's prediction.

When the previous relevant context happened recently, "Vanilla" RNNs have no problem using it to inform the current prediction. For example, if we were predicting the next word in the phrase "the clouds are in the sky", "clouds" appears very close to the place where "sky" is being predicted.

However, in many applications (including text) context has a long distance effect. In these circumstances, "Vanilla RNNs" struggle to "remember" what happened many time steps ago.

Continuing with the text prediction example, suppose the phrase was "I grew up in France, the birthplace of civil liberties. I speak fluent French". "France" is critical context for predicting "French", but it appeared so long ago that a "Vanilla RNN" won't be able to "remember" it to narrow down the prediction. This is why it is said that Vanilla RNNs have short-term memory.

  • We've been talking about the short term memory problems of "Vanilla RNNs" from the inference point of view. However, they have the same problem at training time. During training, "Vanilla RNNs" are unable to connect related concepts if they have a long distance between them.

LSTM Intuition

LSTMs are specifically designed to tackle the short-term memory problem. In fact, having long-term memory is their default behaviour.

An LSTM is a group of interconnected, learned single-layer neural networks and point-wise operations that are grouped together in a "repeating module".

Each of these neural networks and point-wise operations serves a particular purpose in the architecture and is explained below. For now, let's get familiar with the notation:

LSTM notation

The Memory Pipeline - The core idea

The core idea behind LSTMs is the "memory cell / pipeline". Through the memory pipeline, information flows easily and mostly unchanged across the multiple steps of the RNN. Changes to the memory only happen when the gates determine so (see below).

Side note: intuitively, since the information in the pipeline doesn't change much, its derivative over time stays close to 1. This is why LSTMs are less prone to the vanishing and exploding gradient problems.

The memory Pipeline - The core idea behind LSTMs

Step-by-step walk through

LSTM walk through part 1

LSTM walk through part 2
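In code form, the standard gate equations from the walk-through can be sketched as follows (a minimal numpy sketch; the weight names are illustrative, not a library API):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# One LSTM step. x_t: current input, h_prev: previous output,
# C_prev: previous cell state (the memory pipeline). W_*/b_* are learned parameters.
def lstm_step(x_t, h_prev, C_prev, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o):
    z = np.concatenate([h_prev, x_t])   # the gates all look at [h_{t-1}, x_t]
    f = sigmoid(z @ W_f + b_f)          # forget gate: what to erase from memory
    i = sigmoid(z @ W_i + b_i)          # input gate: what to write to memory
    C_tilde = np.tanh(z @ W_c + b_c)    # candidate values to write
    C = f * C_prev + i * C_tilde        # memory pipeline: mostly unchanged unless the gates act
    o = sigmoid(z @ W_o + b_o)          # output gate: which parts of memory to expose
    h = o * np.tanh(C)                  # new output / hidden state
    return h, C

# Toy usage: 4-dimensional inputs, 3-dimensional hidden/cell state.
rng = np.random.default_rng(0)
n_in, n_h = 4, 3
W_f, W_i, W_c, W_o = (rng.normal(size=(n_h + n_in, n_h)) for _ in range(4))
b_f, b_i, b_c, b_o = (np.zeros(n_h) for _ in range(4))
h, C = np.zeros(n_h), np.zeros(n_h)
for x_t in rng.normal(size=(5, n_in)):
    h, C = lstm_step(x_t, h, C, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o)
```

The line `C = f * C_prev + i * C_tilde` is the memory pipeline from the previous section: unless the gates act, the cell state flows through largely unchanged.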

LSTM Variations

There are many variants of the LSTM architecture shown above. Most of them are slight variations of this architecture and have similar performance.
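In practice, frameworks expose this architecture directly. A minimal, hedged Keras sketch of a next-word-prediction style model (the vocabulary size and layer dimensions are arbitrary placeholders):

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense

vocab_size = 10_000  # placeholder vocabulary size

model = Sequential([
    Embedding(input_dim=vocab_size, output_dim=64),  # map word ids to vectors
    LSTM(128),                                       # the LSTM repeating module, applied across the sequence
    Dense(vocab_size, activation='softmax'),         # probability distribution over the next word
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
```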

Beyond LSTMs - Attention Models

LSTMs were a step forward in RNN performance. The next big step in RNN performance is attention. Christopher Olah has an article on attention and other forms of RNN augmentation.

Additional Reading