Introduction to Attention Mechanism in Deep Learning — ELI5 Way
In this article, we will discuss some of the limitations of Encoder-Decoder models, which motivated the development of the Attention Mechanism. After that, we will talk about the concepts of Attention Models and their application in machine translation.
Note: The Attention Mechanism is a slightly advanced topic and requires an understanding of Encoder-Decoder and Long Short-Term Memory models. Kindly refer to my previous articles on Encoder-Decoder Models and LSTM.
Before we discuss the concepts of Attention models, we will start by revisiting the task of machine translation using Encoder-Decoder models.
Citation Note: The content and structure of this article are based on my understanding of the deep learning lectures from One-Fourth Labs (PadhAI).
Machine Translation — Recap
Let’s take the example of translating text from Hindi to English using a seq2seq model, which uses an encoder-decoder architecture. The underlying model in the encoder-decoder architecture can be anything from an RNN to an LSTM.
At a high level, the encoder reads the entire sentence only once and encodes all the information from the previous hidden representations and previous inputs into an encoded vector. Then, at each time-step, the decoder uses this encoded vector to produce a new word.
The problem with this approach is that the encoder reads the entire sentence only once and has to remember everything, compressing the whole sentence into a single fixed-size encoded vector. For longer sentences, the encoder cannot retain the starting parts of the sequence, which results in a loss of information.
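To make the bottleneck concrete, here is a minimal NumPy sketch of a vanilla RNN encoder, assuming randomly initialised (untrained) weights and hypothetical sizes (4-dimensional word embeddings, 8-dimensional hidden state). The names `U`, `W`, `b` and `encode` are illustrative only. Whether the sentence has 3 words or 30, the single vector handed to the decoder has the same size.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hid = 4, 8  # hypothetical embedding and hidden sizes

# Randomly initialised (untrained) vanilla RNN encoder parameters
U = rng.normal(scale=0.1, size=(d_hid, d_in))   # input-to-hidden weights
W = rng.normal(scale=0.1, size=(d_hid, d_hid))  # hidden-to-hidden weights
b = np.zeros(d_hid)                             # bias

def encode(xs):
    """Run a vanilla RNN over input embeddings xs (shape [T, d_in]).

    Returns all hidden states (shape [T, d_hid]); a vanilla
    encoder-decoder only passes the last one to the decoder.
    """
    h = np.zeros(d_hid)
    states = []
    for x in xs:
        # New state depends on the current input and the previous state
        h = np.tanh(U @ x + W @ h + b)
        states.append(h)
    return np.stack(states)

short_sentence = rng.normal(size=(3, d_in))
long_sentence = rng.normal(size=(30, d_in))

# Whatever the sentence length, the summary vector has the same size
print(encode(short_sentence)[-1].shape)  # (8,)
print(encode(long_sentence)[-1].shape)   # (8,)
```

Everything the decoder learns about a 30-word sentence has to squeeze through those 8 numbers, which is exactly the limitation described above.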
Is this how humans translate a sentence?
Do you think that the entire input sequence (or sentence) is equally important at every time-step during decoding? Can we place special emphasis on certain words rather than giving equal importance to all of them? The Attention Mechanism was developed to address these challenges.
We humans translate each word in the output by focusing on only certain words in the input. At each time-step, we take only the relevant information from the long sentence and then translate that particular word. Ideally, at each time-step, we should feed only the relevant information (the encodings of the relevant words) to the decoder for the translation.
Attention Mechanism — Oracle
How do we know which words are important, or which ones we need to pay more attention to? For now, assume that we have an oracle that tells us which words to focus on at a given time-step t. With the oracle’s help, can we design a better architecture that feeds only the relevant information to the decoder?
So for each input word, we assign a weight α (between 0 and 1) that represents the importance of that word for the output at time-step t. For example, α₁₂ represents the importance of the first input word for the output word at the second time-step. In general, αⱼₜ is the weight associated with the jᵗʰ input word at the tᵗʰ time-step.
For example, at time-step 2, we could simply take a weighted average of the input word representations using the weights αⱼ₂ and feed it into the decoder. In this scenario, we are not feeding the complete encoded vector into the decoder, but rather a weighted combination of the word representations. In effect, we give more importance, or attention, to the words that the oracle marks as important. (Thanks, oracle!)
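Here is a tiny numeric sketch of that weighted average, assuming made-up 2-dimensional word representations for a 4-word sentence and hypothetical oracle weights for time-step 2. The weights sum to 1, so the result is a weighted average that leans towards the word the oracle favours.

```python
import numpy as np

# Hypothetical representations of a 4-word input sentence (one row per word)
h = np.array([
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
    [0.0, 0.0],
])

# Hypothetical oracle weights alpha_{j2} for output time-step t = 2:
# word 2 matters most, word 4 not at all (weights sum to 1)
alpha_t2 = np.array([0.1, 0.7, 0.2, 0.0])

# Weighted average of the word representations, fed to the decoder at t = 2
context_t2 = alpha_t2 @ h
print(context_t2)  # [0.3 0.9]
```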
Intuitively this approach should work better than the vanilla version of encoder-decoder architecture because we are not overloading the decoder with irrelevant information.
Model for Attention
Don’t be fooled: in reality, there is no oracle. So if there is no oracle, how do we learn the weights?
Notations: From now on, we will refer to the decoder state at the tᵗʰ time-step as sₜ and the encoder state at the jᵗʰ time-step as hⱼ.
The parameter αⱼₜ has to be learned from the data. To enable this, we define a function f that computes an intermediate score eⱼₜ:
$$e_{jt} = f(s_{t-1}, h_j)$$
The function that calculates the intermediate score eⱼₜ takes two arguments. At the tᵗʰ time-step, we are trying to find out how important the jᵗʰ word is, so the function should depend on the vector representation of the word itself (i.e., hⱼ) and on the decoder state up to that particular time-step (i.e., sₜ₋₁).
The score eⱼₜ captures the importance of the jᵗʰ input word for decoding the tᵗʰ output word. Applying the softmax function over all input positions, we can normalize these scores to get our parameter αⱼₜ, which lies between 0 and 1 and sums to 1 over j (T is the number of input words):
$$\alpha_{jt} = \frac{\exp(e_{jt})}{\sum_{k=1}^{T} \exp(e_{kt})}$$
The parameter αjt denotes the probability of focusing on the jᵗʰ word to produce the tᵗʰ output word.
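The softmax normalization can be sketched in a few lines of NumPy. The raw scores below are hypothetical values for a 4-word input at a single decoder time-step; subtracting the maximum before exponentiating is a standard numerical-stability trick and does not change the result.

```python
import numpy as np

def softmax(e):
    """Normalise raw scores e_{jt} (over input positions j) into weights alpha_{jt}."""
    z = np.exp(e - np.max(e))  # subtract the max for numerical stability
    return z / z.sum()

# Hypothetical raw scores e_{jt} for a 4-word input at one decoder time-step t
e_t = np.array([2.0, 0.5, -1.0, 0.5])
alpha_t = softmax(e_t)

print(alpha_t.round(3))
print(alpha_t.sum())  # 1.0, up to floating-point rounding
```

The word with the highest score receives the largest weight, equal scores receive equal weights, and every weight stays strictly positive, so no word is ever completely ignored.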
In the previous section, we discussed at a high level how to learn the parameter αⱼₜ using a function that takes two arguments: the decoder state before the tᵗʰ time-step (sₜ₋₁) and the vector representation of the word (hⱼ). The output of this function is normalized using softmax to obtain αⱼₜ.
In this section, we will define a parametric form for eⱼₜ so that we can learn it from the data. A commonly used parametric form (additive attention) is:
$$e_{jt} = V_{att}^{T} \tanh(U_{att} h_j + W_{att} s_{t-1})$$
To learn eⱼₜ, we introduce the additional parameters Vₐₜₜ, Uₐₜₜ and Wₐₜₜ. Uₐₜₜ is applied to the encoder state hⱼ, Wₐₜₜ is applied to the previous decoder state sₜ₋₁, and Vₐₜₜ projects the resulting vector down to a single scalar score. These parameters are learned along with the other parameters of the encoder-decoder model.
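A minimal NumPy sketch of the additive score eⱼₜ = Vₐₜₜᵀ tanh(Uₐₜₜ hⱼ + Wₐₜₜ sₜ₋₁), computed for all input positions j at once. The sizes, the random initialisation and the function name `attention_scores` are assumptions for illustration; in a real model these parameters are trained together with the encoder and decoder.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_enc, d_dec, d_att = 5, 8, 6, 10  # hypothetical sizes

# Attention parameters (random here; learned jointly with the model in practice)
U_att = rng.normal(scale=0.1, size=(d_att, d_enc))  # applied to encoder states h_j
W_att = rng.normal(scale=0.1, size=(d_att, d_dec))  # applied to decoder state s_{t-1}
V_att = rng.normal(scale=0.1, size=d_att)           # projects to a scalar score

def attention_scores(H, s_prev):
    """Return e_{jt} = V_att^T tanh(U_att h_j + W_att s_{t-1}) for every j.

    H: encoder states, shape [T, d_enc]; s_prev: decoder state, shape [d_dec].
    """
    # H @ U_att.T has shape [T, d_att]; W_att @ s_prev has shape [d_att]
    # and is broadcast across the T rows
    return np.tanh(H @ U_att.T + W_att @ s_prev) @ V_att

H = rng.normal(size=(T, d_enc))
s_prev = rng.normal(size=d_dec)
e_t = attention_scores(H, s_prev)
print(e_t.shape)  # (5,): one score per input word
```

Vectorising over j is just a convenience: the result matches computing the formula separately for each hⱼ in a loop.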
Machine Translation — Attention Mechanism
In the previous section, we defined the parametric function for learning the weights (eⱼₜ before normalization) that let us give more attention to particular words. In this section, we will walk through the task of machine translation end-to-end using the Attention Mechanism.
In this task, we are translating the input from Hindi to English. For simplicity, we assume that RNNs are used as the encoder and decoder models, but you could also use LSTMs or GRUs.
Encoder:
- An RNN is used as the encoder.
- The encoder operation doesn’t change much compared to the vanilla encoder-decoder architecture without attention.
- At each time-step j, the representation hⱼ of the current word is computed as a function of the previous hidden state hⱼ₋₁ and the current input, along with a bias term.
- The final encoder hidden state contains the encoded information from all the previous hidden representations and inputs. The difference with attention is that we keep every hⱼ available to the decoder instead of only the final one.
Decoder:
- In the vanilla encoder-decoder model, we would pass the single encoded vector to the decoder, which decodes it into the probability distribution of the next possible word.
- Instead of passing the entire encoded vector, we compute the scores eⱼₜ using the fancy equation from the previous section and then normalize them with the softmax function to get αⱼₜ.
- Once we have all the encoder states and the weights associated with them (thanks to the fancy equation!), we compute their weighted combination to get the context vector cₜ = Σⱼ αⱼₜ hⱼ.
- We feed the context vector cₜ, along with the previous decoder state sₜ₋₁, to the decoder RNN to compute the new state sₜ. This decoding operation repeats at every time-step of the output until the translation is complete.
- The output layer is a softmax function: it takes the decoder hidden state sₜ, multiplies it by a weight matrix, adds a bias, and returns the probability distribution of the next word.
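The steps above can be put together in a short NumPy sketch of Encode-Attend-Decode. It is illustrative, not the author's implementation: all weights are random and untrained, the sizes are hypothetical, and several choices are assumptions made for the example. These include feeding the previous output word's embedding into the decoder RNN, initialising the decoder with the last encoder state, treating token id 0 as a hypothetical start token, and decoding greedily for a fixed 4 steps. The score uses the additive form eⱼₜ = Vₐₜₜᵀ tanh(Uₐₜₜ hⱼ + Wₐₜₜ sₜ₋₁), and random vectors stand in for the encoder states hⱼ.

```python
import numpy as np

rng = np.random.default_rng(42)
T_in, d_emb, d_enc, d_dec, d_att, vocab = 6, 4, 8, 8, 10, 12  # hypothetical sizes

def init(*shape):
    return rng.normal(scale=0.1, size=shape)

# Additive attention: e_jt = V_att^T tanh(U_att h_j + W_att s_{t-1})
U_att, W_att, V_att = init(d_att, d_enc), init(d_att, d_dec), init(d_att)
# Decoder RNN: previous state, context vector, previous output embedding (assumed inputs)
W_s, W_c, W_y, b_s = init(d_dec, d_dec), init(d_dec, d_enc), init(d_dec, d_emb), np.zeros(d_dec)
# Output layer and hypothetical target-word embeddings
V_out, c_out = init(vocab, d_dec), np.zeros(vocab)
E_out = init(vocab, d_emb)

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def decoder_step(H, s_prev, y_prev):
    """One Encode-Attend-Decode step. Returns (s_t, output distribution, alpha_t)."""
    e_t = np.tanh(H @ U_att.T + W_att @ s_prev) @ V_att  # scores e_{jt}
    alpha_t = softmax(e_t)                                # attention weights alpha_{jt}
    c_t = alpha_t @ H                                     # c_t = sum_j alpha_{jt} h_j
    s_t = np.tanh(W_s @ s_prev + W_c @ c_t + W_y @ E_out[y_prev] + b_s)
    p_t = softmax(V_out @ s_t + c_out)                    # distribution over next word
    return s_t, p_t, alpha_t

H = rng.normal(size=(T_in, d_enc))  # stand-in for encoder states h_1..h_T
s = H[-1].copy()                    # decoder starts from the last encoder state (assumption)
y = 0                               # hypothetical <start> token id
alphas = []
for t in range(4):                  # greedy decoding for a fixed 4 steps
    s, p, alpha = decoder_step(H, s, y)
    y = int(p.argmax())
    alphas.append(alpha)

A = np.stack(alphas)                # row t = attention over the input words at step t
print(A.shape)                      # (4, 6)
print(A.sum(axis=1))                # each row sums to 1
```

Stacking the attention weights from every step gives an output-by-input matrix; in a trained model, plotting it shows which Hindi words each English word attended to.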
These models are called Encode-Attend-Decode models, also known as Seq2Seq with Attention.
Recommended Reading
- Introduction to Encoder-Decoder Models — ELI5 Way (towardsdatascience.com): discusses the basic concepts of Encoder-Decoder models and its applications in some of the tasks like language modeling…
- Long Short Term Memory and Gated Recurrent Unit’s Explained — ELI5 Way (towardsdatascience.com): in this post, we will learn the intuition behind the working of LSTM and GRU.
In this post, we discussed some of the limitations of the vanilla version of the encoder-decoder model in machine translation. For longer sentences, the encoder will not be able to remember the starting parts of the sequence resulting in the loss of information. After that, we looked at how to feed only relevant information to the decoder or give more attention to the important words that help to retain information over longer sequences.
From there, we discussed the parametric function for learning the attention weights needed to give more importance to certain words during decoding. Finally, we walked through the end-to-end machine translation task using an attention mechanism.
In my next post, we will discuss the implementation of the Attention Mechanism using PyTorch, so make sure you follow me on Medium to get notified as soon as it drops.
Until then, Peace 🙂
Niranjan Kumar is a Senior Consultant, Data Science, at Allstate India. He is passionate about Deep Learning and Artificial Intelligence. Apart from writing on Medium, he also writes for Marktechpost.com as a freelance data science writer.
You can connect with him on LinkedIn or follow him on Twitter for updates about upcoming articles on deep learning and machine learning.