Stepping Through An LSTM Cell
Long Short-Term Memory (LSTM) cells are used in many modern deep learning architectures, especially those that process sequences of text (as in machine translation) or time-series data. There are many good introductory blog posts on LSTMs, for example those by Christopher Olah and Andrej Karpathy. However, it is still instructive to step through a few time-steps of a memory cell by hand and validate the results with actual code. My colleague Josh Whitney did just that, and I offered to write it up for a broader audience.
LSTMs are a variant of Recurrent Neural Networks (RNNs) that preserve the error as it is backpropagated through time and layers. The DL4J website has a detailed description of Fig. 1 that is worth reading. The three (analog) gates block or pass information based on weights that are learned, like other neural network weights, through recurrent network learning (backpropagation and gradient descent). Let us call the three activations associated with the three (recurrent) gates (to the right in Fig. 1) recurrent activations. The central identity activation is a summation that adds the cell’s current state to its input; this is the key to the LSTM’s ability to maintain a constant error during backpropagation. The black dots signify multiplication of the recurrent gate outputs with the output of the cell activations.
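The gating described above can be sketched in a few lines of plain Python. This is a minimal scalar illustration of the standard LSTM update with sigmoid gates and tanh cell activations, not the DL4J or Keras implementation. The function name lstm_step and the dictionaries w, u and b (input weights, recurrent weights and biases, keyed by gate) are hypothetical names chosen for this sketch.

```python
import math


def sigmoid(z):
    """Logistic function, the usual 'recurrent activation' of the gates."""
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x, h_prev, c_prev, w, u, b):
    """One time-step of a single-unit LSTM cell (scalar sketch).

    w, u, b are dicts keyed by 'i' (input gate), 'f' (forget gate),
    'c' (candidate cell value) and 'o' (output gate).
    """
    i = sigmoid(w["i"] * x + u["i"] * h_prev + b["i"])    # input gate
    f = sigmoid(w["f"] * x + u["f"] * h_prev + b["f"])    # forget gate
    o = sigmoid(w["o"] * x + u["o"] * h_prev + b["o"])    # output gate
    g = math.tanh(w["c"] * x + u["c"] * h_prev + b["c"])  # candidate value
    c = f * c_prev + i * g      # additive cell-state update
    h = o * math.tanh(c)        # gated output (hidden state)
    return h, c
```

When the forget gate is close to 1 and the input gate is close to 0, c is carried forward almost unchanged; this additive path is the one along which the error is preserved.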
Fig. 2 is a simplified network that we use for our illustration. It has a single input neuron, identity activations, and all weights set to 1. We feed it an input sequence of length 3 in which every value is 1.
Let us first write some Python code using the Keras interface (with the TensorFlow back-end) to derive the expected sequence output (sequence_out) and the values of the hidden state (state_h) and cell state (state_c). Note that for our simple structure, sequence_out is equal to state_h. As indicated in Fig. 2, the initial state_c is set by Keras to 0.
Let us step through the 3 time-steps one by one and hand-calculate these values, so that we arrive at the final values shown above at the end of the third time-step. Fig. 4 starts with the first input of 1 and flows it through the memory cell to obtain the output 1, as expected. After the first time-step, the recurrent input (the hidden state fed back into the cell) becomes 1, and the cell stores a value of 1.
For the second time-step, we feed 1 (again) as the input, and hand-calculate 12 as the sequence_out and state_h values as shown in Fig. 5. The state_c value is 6.
After the second time-step, the recurrent input (the hidden state fed back into the cell) becomes 12, and the cell stores a value of 6 (the state_c value). For the third time-step, the next sequence input value is again 1. We hand-calculate 3211 as the sequence_out and state_h values, as shown in Fig. 6. The state_c value is 247. These numbers agree with the output of our Keras code in Fig. 3.
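The three hand-calculated time-steps can be re-traced with a short plain-Python loop. This is a sketch of the simplified cell in Fig. 2, not the Keras code behind Fig. 3. It assumes that all biases are zero (the text only states that the weights are 1), and simple_lstm_step is a name made up for this example.

```python
def simple_lstm_step(x, h_prev, c_prev):
    """One step of the simplified cell: weights 1, identity activations.

    Assumption: all biases are 0.
    """
    z = x + h_prev            # every gate sees the same pre-activation
    i = f = o = g = z         # identity activations leave z unchanged
    c = f * c_prev + i * g    # new cell state
    h = o * c                 # new hidden state (also the output)
    return h, c


h, c = 0, 0                   # initial hidden and cell state
trace = []
for x in [1, 1, 1]:           # input sequence of length 3
    h, c = simple_lstm_step(x, h, c)
    trace.append((h, c))

for t, (h_t, c_t) in enumerate(trace, start=1):
    print(f"step {t}: state_h = {h_t}, state_c = {c_t}")
```

Under these assumptions the loop yields 1, 12 and 3211 for state_h and 1, 6 and 247 for state_c, the same values as the hand calculation. In the third step, for instance, every gate evaluates to 1 + 12 = 13, so state_c = 13 * 6 + 13 * 13 = 247 and state_h = 13 * 247 = 3211.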
The internal memory state resets to 0 after the end of the sequence. So, at the end of the third time-step (our sequence length is 3), the cell looks like the one in Fig. 7.
If the batch size is 1, back-propagation and adjustment of weights happen at this point (that is, at the end of each sequence). If the batch size is N, back-propagation happens after N sequences have been fed to the cell. Note that an output is emitted by the cell after each time-step in the sequence.
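The reset between sequences and the per-step outputs can be illustrated with the same simplified cell. The sketch below is plain Python rather than Keras; run_sequence is a hypothetical helper, biases are assumed to be zero, and no back-propagation is performed.

```python
def run_sequence(xs):
    """Feed one sequence to the simplified cell, starting from zero state.

    Assumptions: weights 1, identity activations, biases 0.
    Returns the output emitted after every time-step.
    """
    h, c = 0, 0                  # state is reset at the start of each sequence
    outputs = []
    for x in xs:
        z = x + h                # shared pre-activation of all gates
        c = z * c + z * z        # forget * old state + input * candidate
        h = z * c                # output gate * new cell state
        outputs.append(h)        # one output per time-step
    return outputs


batch = [[1, 1, 1], [1, 1]]      # two sequences fed one after the other
per_sequence_outputs = [run_sequence(seq) for seq in batch]
print(per_sequence_outputs)
```

Because each call starts from zero state, the second sequence produces 1 and 12 again rather than continuing from 3211.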
Josh and I found this exercise helpful in understanding LSTMs. Hopefully, you gained a similar benefit from this article.