์ด๋ฒ ํ๊ธฐ ํ๋ ์๋ ํ ํ๋ก์ ํธ๊ฐ ๋๋ ํ ์ฌ์ฉํ ์ฝ๋์ ๋ํด ๋ณต์ต ๋ชฉ์ ์ผ๋ก ๊ธ์ ์์ฑํ๋ค. ๋ชจ๋ธ์ ์ด 3๊ฐ์ธ๋ฐ CNN, RNN, Transformer ์์ผ๋ก ์ ๋ฆฌํ ์์ ์ด๋ค.
๐ฉ๐ซ ๋ชจ๋ธ ํด๋์ค
class CNN(nn.Module):
def __init__(self, vocab_size, embed_dim, n_filters, filter_size, dropout, num_class):
super(CNN, self).__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.conv1d_layers = nn.ModuleList([nn.Conv1d(in_channels=embed_dim, out_channels=n_filters[i], kernel_size=filter_size[i]) for i in range(len(filter_size))])
self.fc_layer = nn.Linear(np.sum(n_filters), num_class)
self.dropout = nn.Dropout(p=dropout)
def forward(self, sentence):
x_embed = self.embedding(sentence)
x_embed = x_embed.permute(0, 2, 1)
x_conv_list = [F.relu(conv1d(x_embed)) for conv1d in self.conv1d_layers]
x_pool_list = [F.max_pool1d(x_conv, kernel_size=x_conv.shape[2]) for x_conv in x_conv_list]
x_fc_layer = torch.cat([x_pool.squeeze(dim=2) for x_pool in x_pool_list], dim=1)
logits = self.fc_layer(self.dropout(x_fc_layer))
return logits
๐ฏ ํ๋ผ๋ฏธํฐ
โฝ vocab_size: vocab ํฌ๊ธฐ. vocab์ ๋ค์ด์๋ ๋จ์ด์ ๊ฐ์
โฝ embed_dim: ์๋ฒ ๋ฉ ๋ฒกํฐ ์ฐจ์
โฝ n_filters: ํฉ์ฑ๊ณฑ ์ฐ์ฐ์ ์ํ ํํฐ ๊ฐ์
โฝ filter_size: ํ ๋ฒ์ ๋ณผ ๊ธ์ ์
โฝ num_class: ํ๊ฒ์ด ๋๋ ๋ ์ด๋ธ ์
โณ ์๋ ๋ฐฉ์
1. __init__
1) nn.Embedding(vocab_size, embed_dim)
โฝ vocab_size ํ×embed_dim ์ด ํฌ๊ธฐ์ ๋ฃฉ์ ํ ์ด๋ธ ์์ฑ
2) nn.ModuleList(~)
โฝ ๋ชจ๋์ ๋ฆฌ์คํธ ํํ๋ก ์ ์ฅํ๋ nn.ModuleList()๋ฅผ ์ด์ฉํด nn.Conv1d ๋ ์ด์ด๋ฅผ filter_size ๊ธธ์ด๋งํผ ์ฐ๊ฒฐํ๋ค. ๋๋ filter_size = [20, 20, 20]์ผ๋ก ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ 3๊ฐ์ Conv1d ๋ ์ด์ด๊ฐ ์ฐ๊ฒฐ๋๋ค.
3) nn.Conv1d(in_channels=embed_dim, out_channels=n_filters[i], kernel_size=filter_size[i])
โฝ (batch, in_channels, kernel_size)๋ก ๊ตฌ์ฑ๋ conv1d ๋ ์ด์ด์ ์ ๋ ฅ ๋ฒกํฐ์ ํฉ์ฑ๊ณฑ ์ฐ์ฐ์ด ๋๋ฌ์ ๋ ์ถ๋ ฅ ์ฐจ์์ (batch, ํํฐ ์, n_out)์ด๋ค.
โญ n_out ๊ณ์ฐ ๊ณผ์ ์ PyTorch ๊ณต์ ๋ฌธ์๋ฅผ ์ฐธ๊ณ ํ๋ค.
4) nn.Linear(np.sum(n_filters), num_class)
โฝ np.sum(n_filters) ์ฐจ์์ ์ ๋ ฅ์ผ๋ก ๋ฐ์์ num_class ์ฐจ์ ๋ฒกํฐ๋ฅผ ์ถ๋ ฅํ๋ค. nn.Linear ๋ ์ด์ด์ ์ ๋ ฅ์ผ๋ก ๋ค์ด์ค๋ ค๋ฉด nn.Linear(~) 0๋ฒ ์ธ์์ ์ฐจ์์ด ๊ฐ์์ผ ํ๋ค.
2. forward
1) sentence
โฝ ๋ชจ๋ธ ๋๋ฆด ๋ ๋ฐฐ์น ์ฌ์ด์ฆ๋ 16, ๋ฌธ์ฅ์ max_length๋ 152๋ก ์ง์ ํ๊ธฐ ๋๋ฌธ์ sentece์ ํฌ๊ธฐ๋ torch.Size([16, 152])๊ฐ ๋๋ค.
2) x_embed = self.embedding(sentence)
โฝ x_embed๋ sentence๊ฐ ์๋ฒ ๋ฉ ๋ฒกํฐ ํฌ๊ธฐ 100์ธ ์๋ฒ ๋ฉ ๋ ์ด์ด๋ฅผ ํต๊ณผํ ํ ์์ฑ๋๋ ๋ฒกํฐ์ด๋ฏ๋ก x_embed์ ํฌ๊ธฐ๋ torch.Size([16, 152, 100])์ด ๋๋ค. → ์ฐจ์์ด ํ๋ ๋ ์ถ๊ฐ๋จ
3) x_embed = x_embed.permute(0, 2, 1)
โฝ permute ํจ์๋ ํ ๋ณํ ํจ์์ธ๋ฐ (0, 1, 2)์ ๊ฐ๋ค์ ์ ํ ์์๋๋ก ๋ณ๊ฒฝํ๋ค. torch.Size([16, 152, 100])์์ 0=16, 1=152, 2=100์ด๋ฏ๋ก permute ํจ์๋ฅผ ์ ์ฉํ x_embed์ ํฌ๊ธฐ๋ torch.Size([16, 100, 152])๊ฐ ๋๋ค.
4) x_conv_list
โฝ์ ๋ ฅ x_embed๋ฅผ 3๊ฐ์ conv1d ๋ ์ด์ด ๊ฐ๊ฐ์ ํต๊ณผ์์ผ์ฃผ๊ณ ReLU ํจ์๋ฅผ ์ ์ฉํ๋ค. CNN์์๋ ํ์ฑํ ํจ์๋ก ReLU๋ฅผ ์ฃผ๋ก ์ฌ์ฉํ๋ค๊ณ ํด์ ์ Softmax๊ฐ ์๋์ง? ๋ผ๋ ์๊ฐ์ด ๋ค์๋๋ฐ, ์๋ ๋งํฌ๊ฐ ์ข์ ๋ต์์ด ๋์ด ์ฃผ์๋ค.
5) x_pool_list
โฝ ์ด๋ฏธ์ง์ CNN์ ์ ์ฉํ๋ ๊ฒ์ฒ๋ผ MaxPooling1D๋ฅผ ์ ์ฉํ์ฌ ํฉ์ฑ๊ณฑ ์ฐ์ฐ์ผ๋ก ์ป์ ๊ฒฐ๊ณผ(x_conv_list๋ฅผ ์ด๋ฃจ๊ณ ์๋ ๊ฐ๊ฐ์ conv layer)์์ ๊ฐ์ฅ ํฐ ๊ฐ์ ๋ฝ๋๋ค.
6) x_fc_layer
โฝ torch.cat(INPUT, dim=1)์ ์ ๋ ฅ ๋ฒกํฐ๋ฅผ ๋ ๋ฒ์งธ ์ฐจ์ ๋ฐฉํฅ์ผ๋ก ํฉ์น๋ผ๋ ์๋ฏธ์ด๋ค. ์๋ฅผ ๋ค์ด 2*2 ํฌ๊ธฐ์ ๋ ๋ฒกํฐ๋ฅผ dim=1๋ก ์ฃผ๊ณ ํฉ์น๋ค๋ฉด 2*4๊ฐ ๋๊ณ dim=0์ผ๋ก ์ฃผ๊ณ ํฉ์น๋ค๋ฉด 4*2๊ฐ ๋๋ ์์ด๋ค.
โฝ torch.cat ๋ด๋ถ์ '[x_pool.squeeze(dim=2) for x_pool in x_pool_list]'๋ x_pool_list๋ฅผ ์ด๋ฃจ๊ณ ์๋ x_pool ๋ฒกํฐ์์ ํฌ๊ธฐ๊ฐ 1์ด๊ณ dim=2 ์์น์ ์๋ ์ฐจ์๋ง ์ ๊ฑฐํ๋ค๋ ์๋ฏธ์ด๋ค.
'๐ฉโ๐ป' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[์ฝ๋ ๋ฆฌ๋ทฐ] ๋ ธ๋ ์ธต ๋ํ ๊ฐ์ฑ ๋ถ๋ฅ ๋ชจ๋ธ ๊ตฌํ (3): Transformer โ (0) | 2022.12.27 |
---|---|
[์ฝ๋ ๋ฆฌ๋ทฐ] ๋ ธ๋ ์ธต ๋ํ ๊ฐ์ฑ ๋ถ๋ฅ ๋ชจ๋ธ ๊ตฌํ (2) : RNN (0) | 2022.12.21 |
[ART] attack_adversarial_patch_TensorFlowV2.ipynb ์ฝ๋ ๋ถ์ (0) | 2022.01.19 |
[ART] attack_defence_imagenet.ipynb ์ฝ๋ ์ค์ต (0) | 2022.01.18 |
[ART] adversarial_training_mnist.ipynb ์ฝ๋ ๋ถ์ (0) | 2022.01.12 |