A simple overview of RNN, LSTM and Attention Mechanism

Recurrent Neural Networks, Long Short Term Memory and the famous Attention based approach explained

When you delve into a book, you read it in the logical order of its chapters and pages, and for good reason. The ideas you form and your train of thought all depend on what you have understood and retained up to that point in the book. This persistence, the ability to keep some memory and pay attention, helps us develop an understanding of concepts and of the world around us. It lets us think, use our knowledge, write, solve problems and innovate.

Sequential data carries an inherent notion of progress: it advances step by step, with the passage of time. Sequential memory is a mechanism that makes it easier for our brain to recognize patterns in sequences.

Traditional Neural Networks, despite being good at what they do, lack this intuitive, innate tendency for persistence. How would a simple feed-forward Neural Network read a sentence and pass previously gathered information along, so that it can fully understand and relate a sequence of incoming data? It cannot. (Mind you, learning weights during training is not the same as persisting information from one step to the next.)

RNNs to the rescue

There’s something magical about Recurrent Neural Networks.

— Andrej Karpathy

Recurrent Neural Networks address this drawback of vanilla NNs with a simple yet elegant mechanism and are great at modeling sequential data.

rolled RNN

Does RNN look weird to you? Let me explain and remove the confusion.

Take a simple feed-forward Neural Network first, shown below. The input (red dot) comes into the hidden layer (blue dot), which produces the output (black dot).

A simple NN

An RNN feeds its hidden state back to itself at the next time step, forming a loop that passes down much-needed information.

RNN feeding hidden state value to itself

To better understand the flow, look at the unrolled version below: the same RNN is copied once per time step, and each copy receives a different input (a token in the sequence) and produces an output.

Unrolled RNN, from time step 0 to t

The network A takes in an input at each time step, produces an output h, and passes information on to itself for the next input at step t+1. In a typical application, the incoming sequential data is first encoded by the RNN, and that encoding is then fed to another feed-forward network that decides on the intent or action.
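To make the loop concrete, here is a minimal sketch of an unrolled vanilla (Elman-style) RNN in plain Python. The weight names `W_xh`, `W_hh` and `b_h`, the tanh activation and the toy sizes are our own assumptions for illustration, not taken from the figures. The point is that the same weights are reused at every time step while the hidden state carries information forward.

```python
import math

def matvec(W, v):
    """Multiply a matrix W (a list of rows) by a vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def rnn_forward(xs, W_xh, W_hh, b_h, h0):
    """Run a vanilla RNN over the sequence xs and return every hidden state.

    The same weights are reused at each time step; h carries information
    from earlier steps to later ones.
    """
    h = h0
    hidden_states = []
    for x in xs:  # one input vector (token) per time step
        pre = [a + b + c for a, b, c in zip(matvec(W_xh, x), matvec(W_hh, h), b_h)]
        h = [math.tanh(p) for p in pre]
        hidden_states.append(h)
    return hidden_states

# Hypothetical toy weights: 2-dimensional inputs and hidden state.
W_xh = [[0.5, -0.3], [0.8, 0.2]]
W_hh = [[0.1, 0.4], [-0.2, 0.3]]
b_h = [0.0, 0.0]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
states = rnn_forward(xs, W_xh, W_hh, b_h, h0=[0.0, 0.0])
print(len(states))  # 3: one hidden state per time step
```

In a real system the final hidden state (or the whole list of states) would then be handed to a feed-forward layer that makes the decision.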

RNNs have become the go-to networks for tasks that involve sequential data, such as speech recognition, language modeling, translation and image captioning.

Let’s say you ask your in-house AI assistant, Piri (or whatever ‘iri’ you prefer), “what time is it?”. Here we break the question into a sequence of tokens and color-code them.

How RNNs work for the tokens of sequence

Final query retrieved as a result of processing the entire sequence

Memory: An essential requirement for making Neural Networks smart(er)

Humans retrieve information from memory, short-term or long-term, combine it with current information, and reason about the next action (or act on impulse or habit, which is again based on previous experience).

The same idea applies to making an RNN hold on to previous information or state(s). The output of a recurrent neuron at time step t depends not only on the current input but also on all previous inputs up to time step t-1, accumulated in its hidden state, so this mechanism can be considered a form of memory. Any part of a neural network that preserves some state across time steps, even partially, is usually called a memory cell.

Each recurrent neuron produces an output as well as a hidden state, which is passed to the neuron at the next time step.

hidden state

Unrolled RNN with hidden state and output at each time step
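For reference, the hidden state update of a vanilla RNN is commonly written as below. The symbol names are our own choice (an assumption, since the original figure is not reproduced here), and the linear readout for the output y_t is just one common option.

```latex
h_t = \tanh\left(W_{hh}\, h_{t-1} + W_{xh}\, x_t + b_h\right), \qquad
y_t = W_{hy}\, h_t + b_y
```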

Too good to be true? Vanishing Gradients and long-term dependencies

Sometimes the task requires only information from the immediately preceding step, or from just a few steps back. Example: “the clouds are in the sky”. If “sky” is the next word to be predicted by a language model, this is an easy task for an RNN, because the gap (or, better, the dependency) between the current step and the relevant previous information is small.

Short-term dependency: RNNs work like a charm

As you can see, the output at h3 requires X0 and X1. This is a short-term dependency, an easy task for simple RNNs. Are RNNs infallible, then?

What about long sequences and long dependencies? In theory, RNNs are capable of handling such “long-term dependencies.”

However, in practice, as the gap grows and more of the context is required, RNNs become unable to learn these dependencies, and the model would end up thrown out of the window (defenestrated) for sure. These issues were analyzed by and .

In the figure, the distribution of colors changes with each time step: the most recent tokens occupy more area, while earlier ones keep shrinking. This neatly illustrates the problem with RNNs known as “short-term memory”.

The age-old Vanishing Gradient problem: the major issue that makes training RNNs slow and inefficient is vanishing gradients. Training a feed-forward neural network works as follows: a) the forward pass produces some outputs; b) the outputs are used to compute the loss; c) the loss is backpropagated to compute the gradients with respect to the weights; d) these gradients are used to update the weights in order to improve the network’s performance.

Because the gradient for each layer is computed from the gradient of the layer after it (by the chain rule), small gradients shrink by large margins at every layer until they are very close to zero. The initial layers then barely learn, and effective training slows down. In an unrolled RNN each time step plays the role of a layer, so the same shrinking happens across time steps.
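A quick numeric sketch shows the shrinking. We assume a hypothetical scalar RNN h_t = tanh(w * h_{t-1} + x_t), with the recurrent weight w = 0.9 and a constant input of 0.5 chosen purely for illustration. The gradient of the last state with respect to the first is a product of one factor per time step, and each factor is below 1 here.

```python
import math

def gradient_through_time(w, xs, h0=0.0):
    """Return d h_T / d h_0 for the scalar RNN h_t = tanh(w * h_{t-1} + x_t).

    By the chain rule this is the product of the per-step factors
    w * tanh'(pre_t), where tanh'(pre_t) = 1 - h_t**2.
    """
    h = h0
    grad = 1.0
    for x in xs:
        h = math.tanh(w * h + x)
        grad *= w * (1.0 - h * h)  # with |w| < 1, every factor is below 1
    return grad

# Hypothetical setting: recurrent weight 0.9, constant input 0.5.
for T in (1, 5, 10, 20):
    print(T, gradient_through_time(w=0.9, xs=[0.5] * T))
```

The printed gradient drops by orders of magnitude as T grows, which is why the earliest tokens barely influence the weight updates.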

vanishing gradient problem in NNs during back propagation across layers

Hence, vanishing gradients prevent RNNs from learning long-range dependencies across time steps well. The earliest tokens of a sequence end up with little influence, even when they are crucial to the overall context. This inability to learn from long sequences is what results in short-term memory.

Long sequence for RNN raising long-term dependency issue, making RNNs struggle to retain all the needed information

This long-term dependency issue seems to hinder the use of RNNs in real-world applications with sequences of decent length. Does that mean RNNs are ostracized forever?

Not quite. As it turns out, a variation of RNNs, or let’s say a special kind of RNN, comes in handy for solving the long-term dependency issues encountered with simple RNNs.

All hail the eternal hero: LSTM

Long Short-Term Memory (LSTM) addresses the “long-term dependency” requirement (or, put differently, the short-term memory issue), since remembering information for long periods is practically its default behavior.

From RNN to LSTM

Let’s say an RNN can remember only the last step in a sequence of words.

Our data set of sentences:

Dog scares Cat.
Cat scares Mouse.
Mouse scares Dog.

At the beginning of each sentence, the RNN’s previous input is an empty token, so the first words (Dog, Cat and Mouse) are all preceded by <Empty>.

The model sees each of those words followed by “scares”, and each sentence ending with a full stop.

If we then use this model to predict the words of a sentence, ambiguity can lead to predictions like these:

<Empty> Dog scares Dog.

<Empty> Mouse scares Cat.

<Empty> Mouse scares Cat scares Dog.

These DO NOT make any sense, because the earlier context plays no role in deciding the next suitable word; only the last word does.
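We can reproduce this failure with a tiny stand-in for an RNN that remembers only the last step. This is a lookup table of observed transitions (a bigram model), not an actual RNN, and the full stop is split off as its own token; both are simplifying assumptions for illustration.

```python
from collections import defaultdict

sentences = ["Dog scares Cat .", "Cat scares Mouse .", "Mouse scares Dog ."]

# Memory of one step: the allowed next tokens depend only on the current token.
next_tokens = defaultdict(set)
for s in sentences:
    tokens = ["<Empty>"] + s.split()
    for prev, nxt in zip(tokens, tokens[1:]):
        next_tokens[prev].add(nxt)

def is_possible(sentence):
    """True if every transition in the sentence was seen in the data set."""
    tokens = ["<Empty>"] + sentence.split()
    return all(nxt in next_tokens[prev] for prev, nxt in zip(tokens, tokens[1:]))

print(sorted(next_tokens["scares"]))                 # ['Cat', 'Dog', 'Mouse']
print(is_possible("Dog scares Dog ."))               # True, never seen as a whole
print(is_possible("Mouse scares Cat scares Dog ."))  # True
```

Every individual transition was seen in training, so nothing stops the nonsense sentences; only a longer memory of the context could rule them out.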

The fundamental LSTM ideas:

First things first: the notations!

Notations used to explain LSTM

The primary component that makes LSTMs rock is the presence of a cell state/vector for each LSTM node, which is passed down to every node in the chain.

Cell State/Vector running on top of an LSTM cell

This cell state can be modified, if required, through linear interactions at every node, depending on the learned behavior and regulated by other inner components. In simple words, these inner components, known as gates, use activation functions to add information to or remove it from the cell state, thus carrying relevant (sometimes modified) information forward.

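A simple RNN contains a single neural network layer (typically a tanh) that does all of its data manipulation. An LSTM, however, has a more intricate inner structure of four interacting layers: the forget, input and output gates, plus a layer that creates candidate values.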

NOTE: An LSTM can remove information from or add information to the cell state, carefully regulated by structures called gates, which makes it inherently better than a simple RNN at retaining long-term information.

LSTM chain of 3 nodes and internal structure depiction

The gates allow information to be optionally altered and let through. Internally, sigmoid and tanh layers are used for this.

Before we go through step-by-step working of LSTM cell, let’s take a look at what Sigmoid and Tanh activation functions are:

Sigmoid activation:

The sigmoid squashes incoming values into the range 0 to 1. Used in conjunction with point-wise multiplication, it decides how much information gets through: 0 means stop everything, 1 means let everything through.

Sigmoid squashing the input values

Tanh activation:

Tanh, on the other hand, compresses values into the range -1 to +1.

Its negative outputs can decrease values and its positive outputs can increase them, which lets the network remove or add information.

These activation functions also help to curb the outputs which otherwise would just explode after successive multiplications along the chain.
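The two squashing functions, and how a sigmoid acts as a gate, can be sketched in a few lines of Python. The sample inputs and the `info`/`gate` values are hypothetical, chosen only to show a gate that is nearly closed, half open and nearly fully open.

```python
import math

def sigmoid(z):
    """Squash any real number into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

# math.tanh squashes any real number into the range (-1, 1).
for z in (-10.0, -1.0, 0.0, 1.0, 10.0):
    print(f"z={z:6.1f}  sigmoid={sigmoid(z):.4f}  tanh={math.tanh(z):+.4f}")

# Gating: multiply information point-wise by sigmoid outputs.
info = [2.0, -3.0, 5.0]
gate = [sigmoid(-10.0), sigmoid(0.0), sigmoid(10.0)]  # roughly 0, 0.5 and 1
gated = [round(g * v, 3) for g, v in zip(gate, info)]
print(gated)  # [0.0, -1.5, 5.0]
```

No matter how large the input, both outputs stay bounded, which is what keeps values from exploding along the chain.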

LSTM inner workings 🧐

Step 1: To decide what to keep and what to FORGET

The first step is to decide what should be forgotten from the cell state. For this, a sigmoid is used in the forget gate layer. As explained previously, a sigmoid outputs values between 0 and 1. Here 0 means ‘remove completely’ while 1 means ‘keep the information’ as it is.

The forget gate looks at hₜ₋₁ and xₜ, and outputs a number between 0 and 1 for each number in the cell state Cₜ₋₁.

Forget gate layer
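Written out in the commonly used notation, where [hₜ₋₁, xₜ] is the concatenation of the previous hidden state and the current input, and W_f and b_f are the forget gate’s learned weights and bias (names assumed here), the forget gate is:

```latex
f_t = \sigma\left(W_f \cdot [h_{t-1}, x_t] + b_f\right)
```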

Step 2: What new information to add back

The second step decides what new information should be added to the cell state. It involves two parts that work together:

a) Sigmoid gate layer deciding what exactly to change (with 0 to 1 range)

b) Tanh gate layer creating candidate values that could be added to the cell state (with -1 to +1 range)

Sigmoid and Tanh deciding what to manipulate and by how much at input gate layer
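In the same notation (the weight and bias names are again our assumption), the input gate iₜ and the candidate values C̃ₜ are:

```latex
\begin{aligned}
i_t &= \sigma\left(W_i \cdot [h_{t-1}, x_t] + b_i\right) \\
\tilde{C}_t &= \tanh\left(W_C \cdot [h_{t-1}, x_t] + b_C\right)
\end{aligned}
```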

Step 3: To actually update the cell state with information from steps 1 and 2

This step includes manipulating the incoming cell state Cₜ₋₁ to reflect the decisions from forget (step 1) and input gate (step 2) layers with help of point-wise multiplication and addition respectively.

Carrying out the calculations to actually update the cell state

Here fₜ is multiplied point-wise with Cₜ₋₁ to forget information from the cell state, and iₜ * C̃ₜ is the new candidate values C̃ₜ scaled by how much each of them should be updated.
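As an equation, with * denoting point-wise multiplication, followed by a worked example for a single cell-state entry using hypothetical values fₜ = 0.9, Cₜ₋₁ = 2.0, iₜ = 0.5 and C̃ₜ = -0.4:

```latex
\begin{aligned}
C_t &= f_t * C_{t-1} + i_t * \tilde{C}_t \\
    &= 0.9 \cdot 2.0 + 0.5 \cdot (-0.4) = 1.8 - 0.2 = 1.6
\end{aligned}
```

Most of the old value survives (the forget gate is nearly open), and the candidate nudges it slightly downward.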

Step 4: What’s the output

Output gate calculations

The output is a filtered version of the cell state. A sigmoid layer (the output gate) decides which parts of the cell state to output; the cell state is passed through tanh (to scale its values to between -1 and +1) and multiplied by the sigmoid’s output, giving the hidden state output hₜ.
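Putting Steps 1 to 4 together, here is a minimal sketch of one LSTM time step in plain Python. The output gate follows the usual form oₜ = σ(W_o · [hₜ₋₁, xₜ] + b_o) and hₜ = oₜ * tanh(Cₜ). The parameter dictionary and its key names (`W_f`, `b_f` and so on) are hypothetical; real libraries store and order these weights differently.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dense(W, b, v, act):
    """Apply act(W @ v + b) element-wise; W is a list of rows."""
    return [act(sum(w * x for w, x in zip(row, v)) + bi) for row, bi in zip(W, b)]

def lstm_step(x_t, h_prev, c_prev, p):
    """One LSTM time step.

    p holds hypothetical weights W_f, W_i, W_c, W_o (each hidden x (hidden + input))
    and biases b_f, b_i, b_c, b_o (each of length hidden).
    """
    hx = h_prev + x_t                                     # concatenation [h_{t-1}, x_t]
    f = dense(p["W_f"], p["b_f"], hx, sigmoid)            # Step 1: forget gate
    i = dense(p["W_i"], p["b_i"], hx, sigmoid)            # Step 2a: input gate
    c_tilde = dense(p["W_c"], p["b_c"], hx, math.tanh)    # Step 2b: candidate values
    c = [ft * ct + it * cc for ft, ct, it, cc in zip(f, c_prev, i, c_tilde)]  # Step 3
    o = dense(p["W_o"], p["b_o"], hx, sigmoid)            # Step 4: output gate
    h = [ot * math.tanh(ct) for ot, ct in zip(o, c)]      # filtered cell state
    return h, c

# Toy example: hidden size 2, input size 1, all weights and biases zero,
# so every gate outputs 0.5 and every candidate value is 0.
H, X = 2, 1
zeros = [[0.0] * (H + X) for _ in range(H)]
params = {"W_f": zeros, "W_i": zeros, "W_c": zeros, "W_o": zeros,
          "b_f": [0.0] * H, "b_i": [0.0] * H, "b_c": [0.0] * H, "b_o": [0.0] * H}
h, c = lstm_step([1.0], [0.0, 0.0], [2.0, -2.0], params)
print(h, c)
```

With every gate half open, the cell state is simply halved (to [1.0, -1.0]), and the hidden state is half of its tanh.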

RNNs and LSTMs have been widely used for processing and making sense of sequential data. A common approach for seq2seq tasks is the encoder-decoder architecture.

Sequence-to-sequence (seq2seq) models in NLP are used to convert sequences of Type A to sequences of Type B. For example, translation of English sentences to German sentences is a sequence-to-sequence task.

The encoder has an input sequence x1, x2, x3, x4. We denote the encoder states by c1, c2, c3. The encoder outputs a single output vector c, which is passed as input to the decoder. Like the encoder, the decoder is a single-layer RNN; we denote the decoder states by s1, s2, s3 and the network’s output by y1, y2, y3, y4.
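The following sketch shows the shape of this architecture, and its bottleneck, with hypothetical scalar RNN cells. The weights, the scalar states and the rule “the decoder’s output equals its state” are simplifications made only for illustration; real encoders and decoders use vectors and trained weights.

```python
import math

def rnn_step(h, x, w_h=0.5, w_x=1.0):
    """Hypothetical scalar RNN cell: h_t = tanh(w_h * h_{t-1} + w_x * x_t)."""
    return math.tanh(w_h * h + w_x * x)

def encode(xs):
    """Encoder: read the whole input, return all states and the final one as c."""
    h = 0.0
    states = []
    for x in xs:
        h = rnn_step(h, x)
        states.append(h)
    return states, states[-1]

def decode(c, steps):
    """Decoder: starts from c alone; the encoder states are not visible to it."""
    s, y = c, 0.0
    outputs = []
    for _ in range(steps):
        s = rnn_step(s, y)  # input is the previous output
        y = s               # simplification: the output is the state itself
        outputs.append(y)
    return outputs

# Whether the input has 4 or 40 elements, the decoder only ever sees one value c.
for xs in ([0.1, 0.2, 0.3, 0.4], [0.1] * 40):
    states, c = encode(xs)
    print(len(states), c, decode(c, 3))
```

The list of intermediate encoder states is computed but thrown away; everything the decoder knows has to fit into c.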

Attention Mechanism: A superhero! But why? 🤔

The primary question to (re)consider is what exactly we demand from RNNs and LSTMs in an encoder-decoder architecture. The current node or state has access to information about the whole input seen so far (i.e. the information flowing from time step 0 to t-1 is available, in some modified or partial form, to the state at time step t). Therefore, the final state of the RNN (more precisely, of the encoder) must hold information about the entire input sequence.

Simple representation of traditional Encoder-Decoder architecture using RNN/LSTM

A major drawback with this architecture lies in the fact that the encoding step needs to represent the entire input sequence x1, x2, x3, x4 as a single vector c, which can cause information loss as all information needs to be compressed into c. Moreover, the decoder needs to decipher the passed information from this single vector only, a highly complex task in itself indeed.

This is exactly where an upper bound on the model’s understanding and performance sets in quickly: processing long input sequences step by step may simply be too much to ask of simple encoders and decoders.

Attention Mechanism for sequence modelling was first used in the paper quoted below, although the notion of ‘attention’ wasn’t as famous back then and the word is used only sparingly within the paper itself.

As explained in the paper [1]:

A potential issue with this encoder–decoder approach is that a neural network needs to be able to compress all the necessary information of a source sentence into a fixed-length vector. This may make it difficult for the neural network to cope with long sentences, especially those that are longer than the sentences in the training corpus.

Unlike the vanilla Encoder-Decoder approach, the attention mechanism looks at all hidden states of the encoder sequence when making predictions.

Simple vs Attention based encoder-decoder architectures

The difference in the preceding image, explained: in a simple Encoder-Decoder architecture, the decoder has to start making predictions by looking only at the final output of the encoder, which holds condensed information. An attention-based architecture, on the other hand, attends to the hidden state of every encoder node at every decoder time step, and makes predictions after deciding which of those states are more informative.

Wait!!

But how does it decide which states are more or less useful for each prediction at every time step of the decoder?

Solution: Use a Neural Network to “learn” which hidden encoder states to “attend” to and by how much.
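One well-known way to do this is additive (Bahdanau-style) attention. A tiny feed-forward network scores each encoder hidden state against the previous decoder state, a softmax turns the scores into weights, and the weighted sum of encoder states becomes the context vector for that decoder step. The sketch below assumes hypothetical, fixed values for the matrices `W`, `U` and the vector `v`; in a real model they are learned jointly with the encoder and decoder.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def additive_attention(s_prev, enc_states, W, U, v):
    """Additive attention for one decoder step.

    s_prev: previous decoder state; enc_states: encoder hidden states h_1..h_n.
    W, U (matrices) and v (vector) are hypothetical stand-ins for learned weights.
    """
    def matvec(M, x):
        return [sum(m * xi for m, xi in zip(row, x)) for row in M]

    Ws = matvec(W, s_prev)
    # Alignment score e_j = v . tanh(W s_prev + U h_j): a tiny feed-forward net.
    scores = [sum(vk * math.tanh(a + b) for vk, a, b in zip(v, Ws, matvec(U, h)))
              for h in enc_states]
    weights = softmax(scores)  # how much to attend to each encoder state
    context = [sum(w * h[k] for w, h in zip(weights, enc_states))
               for k in range(len(enc_states[0]))]
    return weights, context

enc = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
I = [[1.0, 0.0], [0.0, 1.0]]
weights, context = additive_attention([0.5, -0.5], enc, W=I, U=I, v=[1.0, 1.0])
print([round(w, 3) for w in weights], [round(x, 3) for x in context])
```

The weights always sum to 1, so the context vector is a blend of all encoder states rather than only the last one; training adjusts `W`, `U` and `v` so the blend favors the states that help the current prediction.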
