[Code Review] Implementing a Sentiment Classification Model for Elderly Conversations (2): RNN

geum 2022. 12. 21. 15:31

๊ฐ์„ฑ ๋ถ„๋ฅ˜ ๋ชจ๋ธ ๊ตฌํ˜„ ์‹œ๋ฆฌ์ฆˆ (1) | CNN

 

๐Ÿ‘ฉ‍๐Ÿซ ๋ชจ๋ธ ํด๋ž˜์Šค

import torch
import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, n_layers, dropout, num_class, device):
        super(RNN, self).__init__()

        self.device = device        
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        # Token ids -> dense vectors of size embed_dim
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.gru = nn.GRU(embed_dim, self.hidden_dim, self.n_layers, batch_first=True)

        # Maps the final hidden representation to class scores
        self.linear_layer = nn.Linear(self.hidden_dim, num_class)

    def forward(self, sentence):
        x = self.embed(sentence)  # (batch, seq_len) -> (batch, seq_len, embed_dim)

        # Zero-initialized hidden state: (n_layers, batch, hidden_dim)
        init_hidden = torch.zeros(self.n_layers, x.size(0), self.hidden_dim).to(self.device)

        output, _ = self.gru(x, init_hidden)  # output: (batch, seq_len, hidden_dim)

        # Hidden state of the last time step: (batch, hidden_dim)
        t_hidden = output[:, -1, :]

        # nn.Dropout is not in-place, so the result must be reassigned
        t_hidden = self.dropout(t_hidden)

        logits = self.linear_layer(t_hidden)

        return logits

 

๐ŸŽฏ ํŒŒ๋ผ๋ฏธํ„ฐ

โ€ป ๋ชจ๋ธ ๊ตฌ์กฐ์™€ ์ง์ ‘์ ์ธ ์—ฐ๊ด€์ด ์žˆ๋Š” ํŒŒ๋ผ๋ฏธํ„ฐ๋งŒ ์ •๋ฆฌ

◽ vocab_size: size of the vocabulary, i.e., the number of words in the vocab

◽ embed_dim: dimension of the embedding vectors

◽ hidden_dim: number of features in the hidden state h

◽ n_layers: number of recurrent layers

◽ num_class: number of target labels
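
To see these parameters in context, here is a minimal usage sketch. vocab_size=5000 and n_layers=1 are assumptions for illustration; embed_dim=100, hidden_dim=150, and num_class=6 match the tensor shapes traced in the next section.

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# vocab_size and n_layers are assumed values; the rest match the shapes below
model = RNN(vocab_size=5000, embed_dim=100, hidden_dim=150,
            n_layers=1, dropout=0.5, num_class=6, device=device).to(device)

# Dummy batch: 16 sentences, each padded to 152 token ids
sentence = torch.randint(0, 5000, (16, 152)).to(device)
print(model(sentence).shape)  # torch.Size([16, 6])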

 

 

โณ ์ž‘๋™ ๋ฐฉ์‹

1. __init__

1) nn.GRU(input_size, hidden_size, num_layers, batch_first=True)

โ—ฝํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ๋ฐ›์•„์„œ GRU ๊ตฌ์กฐ๋ฅผ ์…‹ํŒ…ํ•˜๋Š” ๋Š๋‚Œ์œผ๋กœ ์ดํ•ดํ–ˆ๋‹ค. batch_first๋Š” default๊ฐ€ False์ธ๋ฐ batch_first=False์ด๋ฉด (์‹œํ€€์Šค ๊ธธ์ด, ๋ฐฐ์น˜ ํฌ๊ธฐ, ์ž…๋ ฅ ์‚ฌ์ด์ฆˆ) ์ด ํ˜•ํƒœ๋กœ GRU ์…€์— ์ž…๋ ฅ์„ ๋ฐ›๋Š”๋‹ค. batch_first=True๋Š” (๋ฐฐ์น˜ ํฌ๊ธฐ, ์‹œํ€€์Šค ๊ธธ์ด, ์ž…๋ ฅ ์‚ฌ์ด์ฆˆ) ํ˜•ํƒœ๋กœ ์ž…๋ ฅ์„ ๋ฐ›๋Š”๋‹ค.

 

2. forward

1) init_hidden = torch.zeros(~)

◽ Since this is the initial hidden state, every value is 0; its role is simply to initialize a tensor of the matching size.
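
A shape-only sketch with the sizes used in this post (n_layers=1 is an assumption). Incidentally, nn.GRU already defaults h_0 to zeros when it is not passed, so the explicit tensor here mainly pins the device and makes the shape visible:

import torch

# (n_layers, batch size, hidden_dim) -- n_layers=1 assumed
init_hidden = torch.zeros(1, 16, 150)
print(init_hidden.shape)  # torch.Size([1, 16, 150])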

 

 

2) output, _ = self.gru(x, init_hidden)

◽ Passing the input x and the hidden state into the GRU together returns the output and the hidden state after the GRU cells. Looking around at GRU text classification code, it seems fine to drop the hidden state and work with output alone — something I wanted to understand better; the check after the shape notes below shows why it works.

 

① Input x shape: torch.Size([16 (batch size), 152 (max sequence length), 100 (embedding dimension)])

② Output shape: torch.Size([16, 152, 150 (hidden dimension)])
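
On the output vs. hidden state question above: for a unidirectional GRU, output[:, -1, :] (the top layer at the last time step) is exactly the final hidden state of the top layer, hidden[-1], which is why classification code can work from output alone. A quick check with dummy data:

import torch
import torch.nn as nn

gru = nn.GRU(input_size=100, hidden_size=150, num_layers=2, batch_first=True)
x = torch.randn(16, 152, 100)
output, hidden = gru(x)

print(output.shape)  # torch.Size([16, 152, 150]) -- top layer, every time step
print(hidden.shape)  # torch.Size([2, 16, 150])   -- every layer, last time step
print(torch.allclose(output[:, -1, :], hidden[-1]))  # True

This equivalence breaks for a bidirectional GRU or when padded sequences are packed, so the two are not always interchangeable.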

 

3) t_hidden = output[:, -1, :]

◽ Slices the last time step out of the output tensor, leaving a tensor of dimension 0 (rows) × dimension 2 (columns), i.e., batch size × hidden dimension.

 

③ t_hidden shape: torch.Size([16, 150])

 

4) logits = self.linear_layer(t_hidden)

◽ Performs the operation between the 16×150 t_hidden tensor and the 150×6 Linear layer (a matrix product plus bias).

 

④ Final logits shape: torch.Size([16, 6])
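
The same step in isolation: nn.Linear(150, 6) stores its weight transposed, as a (6, 150) matrix, and computes t_hidden @ weight.T + bias. A sketch:

import torch
import torch.nn as nn

linear_layer = nn.Linear(150, 6)
t_hidden = torch.randn(16, 150)

logits = linear_layer(t_hidden)
print(logits.shape)               # torch.Size([16, 6])
print(linear_layer.weight.shape)  # torch.Size([6, 150])

# the same computation written out as a matrix product
manual = t_hidden @ linear_layer.weight.T + linear_layer.bias
print(torch.allclose(logits, manual))  # True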