Note
Simple RNN 이해하기 본문
728x90
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import SimpleRNN, LSTM, Bidirectional
train_X = [[0.1, 4.2, 1.5, 1.1, 2.8], [1.0, 3.1, 2.5, 0.7, 1.1], [0.3, 2.1, 1.5, 2.1, 0.1], [2.2, 1.4, 0.5, 0.9, 1.1]]
print(np.shape(train_X))
(4, 5)
단어 벡터의 차원 : 5
문장의 길이 : 4
2차원 텐서
train_X = [[[0.1, 4.2, 1.5, 1.1, 2.8], [1.0, 3.1, 2.5, 0.7, 1.1], [0.3, 2.1, 1.5, 2.1, 0.1], [2.2, 1.4, 0.5, 0.9, 1.1]]]
train_X = np.array(train_X, dtype=np.float32)
print(train_X.shape)
(1, 4, 5) # (batch_size, timesteps, input_dim)
RNN은 3차원 텐서를 입력 값으로 사용하기 때문에 배치 크기 1을 추가한다.
rnn = SimpleRNN(3) # 다른 표현 SimpleRNN(3, return_sequences=False, return_state=False)
hidden_state = rnn(train_X)
print('hidden state : {}, shape: {}'.format(hidden_state, hidden_state.shape))
hidden state : [[-0.866719 0.95010996 -0.99262357]], shape: (1, 3)
은닉 상태가 3인 RNN, 다른 옵션 return_sequences가 False 이기 때문에 마지막 시점 은닉 상태 값만 출력.
rnn = SimpleRNN(3, return_sequences=True)
hidden_states = rnn(train_X)
print('hidden states : {}, shape: {}'.format(hidden_states, hidden_states.shape))
hidden states : [[[ 0.92948604 -0.9985648 0.98355013]
[ 0.89172053 -0.9984244 0.191779 ]
[ 0.6681082 -0.96070355 0.6493537 ]
[ 0.95280755 -0.98054564 0.7224146 ]]], shape: (1, 4, 3)
return_sequences가 True인 경우 모든 시점 은닉 상태 값 출력.
rnn = SimpleRNN(3, return_sequences=True, return_state=True)
hidden_states, last_state = rnn(train_X)
print('hidden states : {}, shape: {}'.format(hidden_states, hidden_states.shape))
print('last hidden state : {}, shape: {}'.format(last_state, last_state.shape))
hidden states : [[[ 0.29839835 -0.99608386 0.2994854 ]
[ 0.9160876 0.01154806 0.86181474]
[-0.20252597 -0.9270214 0.9696659 ]
[-0.5144398 -0.5037417 0.96605766]]], shape: (1, 4, 3)
last hidden state : [[-0.5144398 -0.5037417 0.96605766]], shape: (1, 3)
return_state가 True일 경우에는 return_sequences의 옵션과 상관없이 마지막 시점 은닉 상태를 출력한다.
'Deep Learning' 카테고리의 다른 글
LSTM 이해하기 (0) | 2022.05.22 |
---|---|
게이트 순환 유닛(Gated Recurrent Unit, GRU) (0) | 2022.05.19 |
장단기 메모리(Long Short-Term Memory, LSTM) (0) | 2022.05.18 |
순환 신경망(Recurrent Neural Network, RNN) (0) | 2022.05.17 |
Comments