Long Short-term Memory (LSTM)
20 May 2023 #nlp
Recurrent neural network (RNN)의 Long-term dependency 문제를 해결하고자 만들어진 프레임워크이다.
핵심적인 아이디어는 이전 시점의 state 정보를 이후 state에 얼마나 반영할지를 결정하는 계산을 추가해주는 것이다. 이것을 위해 forget gate, input gate, output gate의 3가지 Gate와 memory cell이 추가 되었다.
LSTM의 계산 방식과 비교하기 위해 RNN의 계산식을 되짚어보면 다음과 같다.
\[\begin{align*} h_t&=\tau(Wh^{t-1}+Ux^t) \\ \hat{y}_t&=softmax(Vh^t) \\ \end{align*}\]이전 시점($t-1$)의 hidden state와 현재 시점($t$)의 input(둘 다 weighted)을 더하여 tanh를 통과시키면 현재 시점의 hidden state가 된다. 이것에 softmax를 취하면 output이 된다.
LSTM 계산식
LSTM은 RNN의 방식에 residual connection 구조에서 착안한 memory 기능을 더하여 long-term dependency 문제를 해결하고자 했다. Memory 기능을 위해 추가된 계산들은 다음과 같다.
Forget gate
Forget gate $f$는 이전 시점의 정보를 얼마나 잊을지 결정하는 gate이다.
\[f_t=\sigma(W_fh_{t-1}+U_fx_t)\]이전 시점의 hidden state와 현재 시점의 input을 더한 뒤 sigmoid를 취한다. 이것이 이전 시점의 memory cell state에 곱해진다. sigmoid의 특성에 의해 1에 가까울수록 이전 정보가 이후 많이 반영된다.
Input gate
Input gate $i$는 현재 시점의 input을 다음 시점에 얼마나 반영할지 결정하는 gate이다. 여기서 candidate $\hat(c)$라는 개념이 등장하는데, candidate는 이전 시점의 hidden state와 현재 시점의 input을 고려했을 때 현재의 정보가 어떠한지를 나타내는 cell state의 후보 격인 값이다. 계산 방식이 RNN의 hidden state와 동일하다. input gate 값과 candidate를 곱해 현재의 정보 상 input이 얼마나 반영되면 좋은지를 구하고, 이것을 최종적으로 cell state에 더한다.
\[\begin{align} i_t&=\sigma(W_{in}h_{_t-1}+U_{in}x_{t})\\ \hat{C}_t&=\tau(W_{c}h_{t-1}+U_{c}x_t) \end{align}\]Memory cell
Memory cell (cell state)은 세 가지 gate와 함께 LSTM의 구현 목적을 위해 추가된 개념이다. 현재 시점의 cell state는 이전 시점의 cell state 및 현재 시점의 forget gate와 현재 시점의 input gate 및 candidate로 계산한다.
\[C_t=f_t*C_{t-1}+i_t*\hat{C}_t\]$*$는 pointwise operation
이전 정보인 cell state와 현재 input을 얼마나 반영할지가 합해져 현재 시점의 cell state가 구해진다.
Output gate
Output gate는 memory cell을 현재 시점의 hidden state에 얼마나 반영할지 결정한다.
\[\begin{align} o_t&=\sigma(W_oh_{t-1}+U_ox_t) \\ h_t&=o_t\tau(C_t) \\ &=o_t\tau(f_t*C_{t-1}+i_t*\hat{C}_t) \\ \end{align}\]현재 시점의 hidden state는 이전 시점의 정보와 현재 시점의 input이 반영된 현재 시점의 cell state와 output gate의 결과값과 곱해져 최종 결정된다.
Output
최종 출력 $\hat{y}_t$은 RNN과 같이 계산된다.
\[\hat{y}_t=softmax(Vh_t)\]LSTM의 한계
LSTM은 cell state 도입을 통해 gradient vanishing 문제를 해결하고자 하였다. 하지만 RNN 구조를 기반으로 하고 있는 한 이 문제를 완벽하게 해결하기에 한계가 있다. 오히려 gate를 여러 개 사용하여 계산량이 증가하는 문제가 있다.