
Understanding LSTM

Unknown user, 2022. 5. 22. 22:30
from tensorflow.keras.layers import LSTM

# train_X is assumed to be defined already: a float32 array of shape (1, 4, input_dim), i.e. one sample with 4 time steps.
# return_state=True adds the final hidden state and the final cell state as extra outputs.
lstm = LSTM(3, return_sequences=False, return_state=True)
hidden_state, last_state, last_cell_state = lstm(train_X)

print('hidden state : {}, shape: {}'.format(hidden_state, hidden_state.shape))
print('last hidden state : {}, shape: {}'.format(last_state, last_state.shape))
print('last cell state : {}, shape: {}'.format(last_cell_state, last_cell_state.shape))

hidden state : [[-0.00263056  0.20051427 -0.22501363]], shape: (1, 3)
last hidden state : [[-0.00263056  0.20051427 -0.22501363]], shape: (1, 3)
last cell state : [[-0.04346419  0.44769213 -0.2644241 ]], shape: (1, 3)

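What sets LSTM apart from SimpleRNN is that, with return_state=True, it returns not only the hidden state at the last time step but also the cell state. Because return_sequences=False here, the first output is itself the hidden state at the last time step, which is why the first two printed values are identical.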
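A minimal sketch of where the extra output comes from, written in plain Python so it runs without TensorFlow. It assumes the standard LSTM update with sigmoid gates and tanh for the candidate values and the output (the Keras default activations). The helper names (`lstm_layer`, `gate`, `random_params`), the weight layout, and the input values are hypothetical; the weights are random rather than Keras's initialisers and the batch axis is dropped, so the numbers will not match the output above. The loop carries two states, h and c, and the return value mimics the structure Keras uses for the return_sequences/return_state flags. For contrast, a SimpleRNN keeps only h, which is why it returns one state rather than two when return_state=True.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def gate(params, name, x, h, act):
    """Compute act(W x + U h + b) for one gate; params[name] = (W, U, b)."""
    W, U, b = params[name]
    z = [wx + uh + bi for wx, uh, bi in zip(matvec(W, x), matvec(U, h), b)]
    return [act(v) for v in z]


def lstm_layer(xs, params, return_sequences=False, return_state=False):
    """Run an LSTM over one sequence xs (a list of input vectors).

    Hypothetical helper mirroring the structure of Keras's return value,
    minus the batch axis: output only, or (output, last_h, last_c).
    """
    units = len(params["i"][2])
    h = [0.0] * units  # hidden state, initialised to zeros
    c = [0.0] * units  # cell state, initialised to zeros
    hs = []
    for x in xs:
        i = gate(params, "i", x, h, sigmoid)    # input gate
        f = gate(params, "f", x, h, sigmoid)    # forget gate
        g = gate(params, "c", x, h, math.tanh)  # candidate cell values
        o = gate(params, "o", x, h, sigmoid)    # output gate
        c = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c, i, g)]
        h = [ok * math.tanh(ck) for ok, ck in zip(o, c)]
        hs.append(h)
    output = hs if return_sequences else h
    return (output, h, c) if return_state else output


def random_params(input_dim, units, seed=0):
    """Hypothetical small random weights (not Keras's initialisers)."""
    rng = random.Random(seed)

    def mat(rows, cols):
        return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

    return {name: (mat(units, input_dim), mat(units, units), [0.0] * units)
            for name in "ifco"}


# Hypothetical input: 4 time steps with 2 features each.
xs = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5], [-0.2, 0.4]]
params = random_params(input_dim=2, units=3)

out, last_h, last_c = lstm_layer(xs, params, return_state=True)
print(out == last_h)             # True: without return_sequences, the output is the last hidden state
print(len(last_h), len(last_c))  # 3 3
```

Keeping c separate from h is the design point: c is updated additively through the forget and input gates, while h is a gated, squashed view of it, so exposing the final state of an LSTM requires both vectors.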

# return_sequences=True: the first output now holds the hidden state at every time step.
lstm = LSTM(3, return_sequences=True, return_state=True)
hidden_states, last_hidden_state, last_cell_state = lstm(train_X)

print('hidden states : {}, shape: {}'.format(hidden_states, hidden_states.shape))
print('last hidden state : {}, shape: {}'.format(last_hidden_state, last_hidden_state.shape))
print('last cell state : {}, shape: {}'.format(last_cell_state, last_cell_state.shape))

hidden states : [[[ 0.1383949   0.01107763 -0.00315794]
  [ 0.0859854   0.03685492 -0.01836833]
  [-0.02512104  0.12305924 -0.0891041 ]
  [-0.27381724  0.05733536 -0.04240693]]], shape: (1, 4, 3)
last hidden state : [[-0.27381724  0.05733536 -0.04240693]], shape: (1, 3)
last cell state : [[-0.39230722  1.5474017  -0.6344505 ]], shape: (1, 3)
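To check how the three outputs of this second run relate, the sketch below copies the printed values into plain Python lists (no TensorFlow needed). It confirms that the last row of hidden_states is exactly last_hidden_state; with NumPy arrays the same slice would be hidden_states[:, -1, :]. It also shows that only the cell state has an entry outside (-1, 1): with the Keras default activations the hidden state is h_t = o_t * tanh(c_t), where the output gate o_t comes from a sigmoid, so every hidden-state entry is smaller than 1 in magnitude, while the cell state has no such bound (here one entry is 1.5474017).

```python
# Values copied verbatim from the printed output above (return_sequences=True run).
hidden_states = [[
    [0.1383949, 0.01107763, -0.00315794],
    [0.0859854, 0.03685492, -0.01836833],
    [-0.02512104, 0.12305924, -0.0891041],
    [-0.27381724, 0.05733536, -0.04240693],
]]  # shape (1, 4, 3): batch, time steps, units
last_hidden_state = [[-0.27381724, 0.05733536, -0.04240693]]  # shape (1, 3)
last_cell_state = [[-0.39230722, 1.5474017, -0.6344505]]      # shape (1, 3)

# Nested-list equivalent of hidden_states[:, -1, :]: the last time step of each sample.
final_step = [sample[-1] for sample in hidden_states]
print(final_step == last_hidden_state)  # True

# Hidden-state entries stay inside (-1, 1); the cell state here does not.
max_abs_h = max(abs(v) for sample in hidden_states for step in sample for v in step)
max_abs_c = max(abs(v) for row in last_cell_state for v in row)
print(max_abs_h < 1, max_abs_c > 1)  # True True
```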
