👩‍💻

[Code Review] Implementing a Sentiment Classification Model for Elderly Conversations (1): CNN

geum 2022. 12. 13. 02:05

Now that the term project I had this semester is over, I'm writing this post to review the code we used. There are three models in total, and I plan to cover them in the order CNN, RNN, Transformer.

 

๐Ÿ‘ฉ‍๐Ÿซ ๋ชจ๋ธ ํด๋ž˜์Šค

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, n_filters, filter_size, dropout, num_class):
        super(CNN, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv1d_layers = nn.ModuleList([
            nn.Conv1d(in_channels=embed_dim, out_channels=n_filters[i], kernel_size=filter_size[i])
            for i in range(len(filter_size))
        ])
        self.fc_layer = nn.Linear(np.sum(n_filters), num_class)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, sentence):
        x_embed = self.embedding(sentence)    # (batch, seq_len) -> (batch, seq_len, embed_dim)
        x_embed = x_embed.permute(0, 2, 1)    # -> (batch, embed_dim, seq_len) for Conv1d

        x_conv_list = [F.relu(conv1d(x_embed)) for conv1d in self.conv1d_layers]
        x_pool_list = [F.max_pool1d(x_conv, kernel_size=x_conv.shape[2]) for x_conv in x_conv_list]

        x_fc_layer = torch.cat([x_pool.squeeze(dim=2) for x_pool in x_pool_list], dim=1)

        logits = self.fc_layer(self.dropout(x_fc_layer))

        return logits

 

๐ŸŽฏ ํŒŒ๋ผ๋ฏธํ„ฐ

◽ vocab_size: size of the vocabulary, i.e. the number of words it contains

◽ embed_dim: dimension of the embedding vectors

◽ n_filters: number of filters for the convolution (one entry per Conv1d layer)

◽ filter_size: how many tokens a filter looks at at once

◽ dropout: dropout probability applied before the final linear layer

◽ num_class: number of target labels

 

โณ ์ž‘๋™ ๋ฐฉ์‹

1. __init__

1) nn.Embedding(vocab_size, embed_dim)

◽ Creates a lookup table of size vocab_size rows × embed_dim columns
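A minimal sketch of the lookup (vocab_size=1000 is an arbitrary value chosen for illustration; embed_dim=100 is the value used later in the post):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(1000, 100)   # vocab_size=1000 (assumed), embed_dim=100
print(emb.weight.shape)         # torch.Size([1000, 100]) -- the lookup table itself

ids = torch.tensor([[3, 7, 7]])
# the same token id always maps to the same table row
print(torch.equal(emb(ids)[0, 1], emb(ids)[0, 2]))  # True
```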

 

2) nn.ModuleList(~)

โ—ฝ ๋ชจ๋“ˆ์„ ๋ฆฌ์ŠคํŠธ ํ˜•ํƒœ๋กœ ์ €์žฅํ•˜๋Š” nn.ModuleList()๋ฅผ ์ด์šฉํ•ด nn.Conv1d ๋ ˆ์ด์–ด๋ฅผ filter_size ๊ธธ์ด๋งŒํผ ์—ฐ๊ฒฐํ•œ๋‹ค. ๋‚˜๋Š” filter_size = [20, 20, 20]์œผ๋กœ ์‚ฌ์šฉํ–ˆ๊ธฐ ๋•Œ๋ฌธ์— 3๊ฐœ์˜ Conv1d ๋ ˆ์ด์–ด๊ฐ€ ์—ฐ๊ฒฐ๋œ๋‹ค.

 

3) nn.Conv1d(in_channels=embed_dim, out_channels=n_filters[i], kernel_size=filter_size[i])

◽ The Conv1d layer takes an input of shape (batch, in_channels, seq_len); once the convolution finishes, the output shape is (batch, number of filters, n_out).

์›Œํ„ฐ๋งˆํฌ ์ฐ๊ธฐ ๊ท€์ฐฎ์•„์„œ ํŒจ์Šค!

โญ n_out ๊ณ„์‚ฐ ๊ณผ์ •์€ PyTorch ๊ณต์‹ ๋ฌธ์„œ๋ฅผ ์ฐธ๊ณ ํ–ˆ๋‹ค.

 

4) nn.Linear(np.sum(n_filters), num_class)

◽ Takes a np.sum(n_filters)-dimensional input and outputs a num_class-dimensional vector. For a tensor to enter the nn.Linear layer, its last dimension must match the 0th argument of nn.Linear(~).

 

2. forward

1) sentence

โ—ฝ ๋ชจ๋ธ ๋Œ๋ฆด ๋•Œ ๋ฐฐ์น˜ ์‚ฌ์ด์ฆˆ๋Š” 16, ๋ฌธ์žฅ์˜ max_length๋Š” 152๋กœ ์ง€์ •ํ–ˆ๊ธฐ ๋•Œ๋ฌธ์— sentece์˜ ํฌ๊ธฐ๋Š” torch.Size([16, 152])๊ฐ€ ๋œ๋‹ค.

2) x_embed = self.embedding(sentence)

◽ x_embed is the tensor produced when sentence passes through the embedding layer with embedding dimension 100, so its shape is torch.Size([16, 152, 100]). → one extra dimension is added

3) x_embed = x_embed.permute(0, 2, 1)

◽ permute reorders a tensor's dimensions in the order given. In torch.Size([16, 152, 100]), dim 0=16, dim 1=152, dim 2=100, so after applying permute(0, 2, 1) the shape of x_embed becomes torch.Size([16, 100, 152]).
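A one-liner check of the reordering:

```python
import torch

x = torch.randn(16, 152, 100)   # (batch, seq_len, embed_dim)
y = x.permute(0, 2, 1)          # keep dim 0, swap dims 1 and 2
print(y.shape)                  # torch.Size([16, 100, 152])
```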

 

4) x_conv_list

◽ The input x_embed is passed through each of the three Conv1d layers, and the ReLU function is applied. Having heard that CNNs mainly use ReLU as the activation function, I wondered "why not Softmax?", and the link below turned out to be a good answer.

 

🔎 https://datascience.stackexchange.com/questions/115276/cnn-model-why-is-relu-used-in-conv1d-layer-and-in-the-first-dense-layer

 

5) x_pool_list

โ—ฝ ์ด๋ฏธ์ง€์— CNN์„ ์ ์šฉํ•˜๋Š” ๊ฒƒ์ฒ˜๋Ÿผ MaxPooling1D๋ฅผ ์ ์šฉํ•˜์—ฌ ํ•ฉ์„ฑ๊ณฑ ์—ฐ์‚ฐ์œผ๋กœ ์–ป์€ ๊ฒฐ๊ณผ(x_conv_list๋ฅผ ์ด๋ฃจ๊ณ  ์žˆ๋Š” ๊ฐ๊ฐ์˜ conv layer)์—์„œ ๊ฐ€์žฅ ํฐ ๊ฐ’์„ ๋ฝ‘๋Š”๋‹ค.

 

Source: https://wikidocs.net/80437

 

6) x_fc_layer

◽ torch.cat(INPUT, dim=1) means concatenating the input tensors along the second dimension. For example, concatenating two 2×2 tensors with dim=1 gives a 2×4 tensor, while dim=0 gives a 4×2 tensor.
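The 2×2 example in code:

```python
import torch

a = torch.ones(2, 2)
b = torch.zeros(2, 2)
print(torch.cat([a, b], dim=1).shape)  # torch.Size([2, 4])
print(torch.cat([a, b], dim=0).shape)  # torch.Size([4, 2])
```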

 

โ—ฝ torch.cat ๋‚ด๋ถ€์˜ '[x_pool.squeeze(dim=2) for x_pool in x_pool_list]'๋Š” x_pool_list๋ฅผ ์ด๋ฃจ๊ณ  ์žˆ๋Š” x_pool ๋ฒกํ„ฐ์—์„œ ํฌ๊ธฐ๊ฐ€ 1์ด๊ณ  dim=2 ์œ„์น˜์— ์žˆ๋Š” ์ฐจ์›๋งŒ ์ œ๊ฑฐํ•œ๋‹ค๋Š” ์˜๋ฏธ์ด๋‹ค.