๊ฐ์ฑ ๋ถ๋ฅ ๋ชจ๋ธ ๊ตฌํ ์๋ฆฌ์ฆ (1) | CNN
๊ฐ์ฑ ๋ถ๋ฅ ๋ชจ๋ธ ๊ตฌํ ์๋ฆฌ์ฆ (2) | RNN
Transformer ๋ถ๋ฅ ๋ชจ๋ธ์ ๋จ์ผ ํ์ผ์ด ์๋๋ผ์ ํ๋์ฉ ๋ถ์ํ๋ฉด ๊ธ์ด 3๊ฐ๋ 4๊ฐ ์ ๋ ๋์ฌ ๊ฒ ๊ฐ๋ค.
๐ฉ๐ซ ๋ชจ๋ธ ํด๋์ค
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from .encoder import Encoder, EncoderLayer
from .sublayers import *
attn = MultiHeadAttention(8, 152)
ff = PositionwiseFeedForward(152, 1024, 0.5)
pe = PositionalEncoding(152, 0.5)
class Transformer(nn.Module):
def __init__(self, vocab_size, d_model, n_layer, num_class):
super(Transformer, self).__init__()
self.encoder = Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff)), n_layer)
self.src_embed = nn.Sequential(Embeddings(d_model, vocab_size), deepcopy(pe))
self.linear = nn.Linear(d_model, num_class)
def forward(self, x):
x = self.src_embed(x)
x = self.encoder(x)
x = x[:, -1, :]
x = self.linear(x)
logits = F.softmax(x, dim=-1)
return logits
์ฒ์์๋ attn, ff, pe ๋ณ์๋ ํ๋ผ๋ฏธํฐ๋ก ๋๊ธฐ๋ ค๊ณ ํ๋๋ฐ ๋ญ๊ฐ ์ ์๋ผ์ ํด๋์ค๋ ํ ํ์ผ์ ๊ฐ์ด ๋๋ค. ์ข์ ๋ฐฉ๋ฒ์ ์๋๋ผ๊ณ ์๊ฐํ๋ค.
๐ฏ ํ๋ผ๋ฏธํฐ
โฝ d_model: ์๋ฒ ๋ฉ ๋ฒกํฐ ์ฐจ์
โฝ n_layer: ์ธ์ฝ๋ ๋ ์ด์ด ์(์๋ณธ ๋ ผ๋ฌธ์์๋ 6๊ฐ์ ์ธ์ฝ๋ ๋ ์ด์ด๋ฅผ ์ด์ด ๋ถ์ฌ์ ํ๋์ ์ธ์ฝ๋๋ก ์ฌ์ฉ)
โณ ์๋ ๋ฐฉ์
1. __init__
1) Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff)), n_layer)
โป ์ด ํํธ๋ Encoder ํด๋์ค, EncoderLayer ํด๋์ค์ ๋ํ ๊ธ์ด ์์ฑ๋๋ฉด ๋งํฌ๋ฅผ ์ถ๊ฐํด๋์ ์์ ์ด๋ค.
โฝ ๋งค์ฐ ๊ฐ๋จํ๊ฒ ์ค๋ช ํ๋ฉด (d_model, deepcopy(attn), deepcopy(ff))๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฐ๋ EncoderLayer๋ฅผ n_layer๊ฐ ์ฌ์ฉํ๋ค๋ ์๋ฏธ์ด๋ค.
2) nn.Sequential(Embeddings(d_model, vocab_size)), deepcopy(pe))
โฝ ์๋ฒ ๋ฉ์ธต๊ณผ positional encoding ์ ๋ณด๋ฅผ ํจ๊ป ์ธ์ฝ๋์ ์ ๋ ฅ์ผ๋ก ๋ฃ์ด์ฃผ๊ธฐ ์ํด nn.Sequential๋ก ์ฐ๊ฒฐํ๋ค. Positional encoding ์ธต ์ฐจ์์ d_model๊ณผ ๋ํด์ ธ์ผ ํ๊ธฐ ๋๋ฌธ์ d_model ์ฐจ์๊ณผ ๋์ผํ๋ค.
2. forward
1) x = self.src_embed(x)
โฝ torch.Size([16(๋ฐฐ์น ํฌ๊ธฐ), 152(๋ฌธ์ฅ ์ต๋ ๊ธธ์ด)])๋ฅผ ๊ฐ๋ ์ ๋ ฅ x๋ ์๋ฒ ๋ฉ ์ธต์ ๊ฑฐ์ณ torch.Size([16, 152, 152])๋ก ์ฐจ์์ด ๋ฐ๋๋ค.
2) x = x[:, -1, :]
โฝ x์ ์ฐจ์์ (16, 152)๋ก ๋ณ๊ฒฝํ๋ค. → torch.Size([16, 152])
3) x = self.linear(x)
โฝ linear ๋ ์ด์ด๋ d_model ์ฐจ์ ๋ฒกํฐ๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฐ์์ num_class ์๋งํผ ์ถ๋ ฅ ๋ฒกํฐ๋ฅผ ๋ง๋ ๋ค. ์ด๋ ๊ฒ ๋ง๋ค์ด์ง ๋ฒกํฐ์ softmax ํจ์๋ฅผ ์ ์ฉํ๋ฉด ์๋์ ๊ฐ์ด ๋ ์ด๋ธ๋ณ ํ๋ฅ ์ด ๋์ค๊ฒ ๋๊ณ , ๊ทธ ์ค ํ๋ฅ ์ด ์ต๋์ธ ๋ ์ด๋ธ์ด ๋ชจ๋ธ์ ์ต์ข ์์ธก ๊ฒฐ๊ณผ๊ฐ ๋๋ค.
'๐ฉโ๐ป' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[์ฝ๋ ๋ฆฌ๋ทฐ] ๋ ธ๋ ์ธต ๋ํ ๊ฐ์ฑ ๋ถ๋ฅ ๋ชจ๋ธ ๊ตฌํ (2) : RNN (0) | 2022.12.21 |
---|---|
[์ฝ๋ ๋ฆฌ๋ทฐ] ๋ ธ๋ ์ธต ๋ํ ๊ฐ์ฑ ๋ถ๋ฅ ๋ชจ๋ธ ๊ตฌํ (1) : CNN (0) | 2022.12.13 |
[ART] attack_adversarial_patch_TensorFlowV2.ipynb ์ฝ๋ ๋ถ์ (0) | 2022.01.19 |
[ART] attack_defence_imagenet.ipynb ์ฝ๋ ์ค์ต (0) | 2022.01.18 |
[ART] adversarial_training_mnist.ipynb ์ฝ๋ ๋ถ์ (0) | 2022.01.12 |