ํฐ์คํ ๋ฆฌ ๋ทฐ
Phillip Isola
1. Introduction
Image-to-Image Translation์ ์ด๋ฏธ์ง๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฐ์์ ๋ ๋ค๋ฅธ ์ด๋ฏธ์ง๋ฅผ ์ถ๋ ฅ์ผ๋ก ๋ฐํํ๋ Task๋ฅผ ๋ปํ๋ค.
๋ณธ ๋ ผ๋ฌธ์ Image-to-Image Translation์ ์ ํฉํ cGAN์ ๊ธฐ๋ฐ์ผ๋กํ๋ฉฐ ๋ค์ํ Task์์ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ด๋ ํ๋ ์์ํฌ pix2pix๋ฅผ ๋ค๋ฃฌ๋ค. ์๋ ์ธ์ด ๋ฒ์ญ์ด ๊ฐ๋ฅํ ๊ฒ์ฒ๋ผ ์๋ image-to-image ๋ณํ ๋ํ ์ถฉ๋ถํ ํ์ต ๋ฐ์ดํฐ๊ฐ ์ฃผ์ด์ง๋ค๋ฉด ํ ์ฅ๋ฉด์ ํํ์ ๋ค๋ฅธ ์ฅ๋ฉด์ผ๋ก ๋ณํํ๋ ์์ ์ผ๋ก ์ ์ํ ์ ์๋ค.
DCGAN๊ณผ ๋ค๋ฅธ์ ์ Generator(G)์ input์ด random vector๊ฐ ์๋๋ผ condition input ๋ผ๋ ์ ์ด๋ค.
๋ณธ ๋ ผ๋ฌธ์์ ์ฐ๋ฆฌ์ ๋ชฉํ๋ ์ด๋ฌํ ๋ชจ๋ ๋ฌธ์ ์ ๋ํ ๊ณตํต๋ ํ๋ ์์ํฌ๋ฅผ ๊ฐ๋ฐํ๋ ๊ฒ์ด๋ค. -> ๋ค์ํ ๊ณผ์ ํด๊ฒฐ๊ฐ๋ฅ!
์ด๋ฏธ์ง ์์ธก์์ ์ฌ์ฉ๋๋ CNN์์ L1, L2์ ๊ฐ์ ์ ํด๋ฆฌ๋์ ๊ฑฐ๋ฆฌ๋ฅผ ์ฌ์ฉํ loss๋ ํ๋ฆฟํ ๊ฒฐ๊ณผ๋ฅผ ์์ฑํ๋, GAN loss๋ฅผ ํจ๊ป ์ฌ์ฉํ๋ค๋ฉด ์ ๋ช ํ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ์ ์๊ฒ ๋๋ค.
2. Related work
Structed losses for image modeling, Conditional GANs ๊ด๋ จ ๋ด์ฉ
* ๋ณธ ๋ ผ๋ฌธ์ ์ฐ๊ตฌ๋ ํน์ ๋ชฉ์ ์ด ์๋ ๊ฒ์ด ํน์ง์.
3. Method
GANs๋ ๋๋ค ๋ ธ์ด์ฆ ๋ฒกํฐ ๋ฅผ ์ถ๋ ฅ ์ด๋ฏธ์ง ๋ก ๋งคํํ๋ ๋งคํ (์ ์ํํ๋ ๋ชจ๋ธ์ด๋ค. ๋์กฐ์ ์ผ๋ก ์กฐ๊ฑด๋ถ GANs๋ ์กฐ๊ฑด์ ํด๋นํ๋ ์ด๋ฏธ์ง ์ ๋๋ค ๋ ธ์ด์ฆ ๋ฒกํฐ ๋ฅผ ์ถ๋ ฅ ์ด๋ฏธ์ง ๋ก ๋งคํํ๋ ๋งคํ (์ ํ์ตํด ์ํํ๋ค.
์ ๊ทธ๋ฆผ์์๋ (edge)๋ฅผ ์กฐ๊ฑด์ผ๋ก ๋ฐ์ ์ค์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๊ณ , ํ๋ณ์ ๋ํ ์ด (edge)์ ์์ฑ ์ด๋ฏธ์ง๋ฅผ ๋ฐ์ ํ๋ณํ๋ฉด์ ๋ชจ๋ธ์ ๋ฐ์ ์ํค๊ฒ ๋๋ค.
3.1 Objective
GAN์ objective loss๋ ๋ค์๊ณผ ๊ฐ๋ค.
Conditional GAN์ ๊ธฐ์กด GAN์์ condition x๋ง ์ถ๊ฐ๋ ๊ฒ.
์๋ cGAN์์๋ latent vector z์ ์ปจ๋์ ๋ฒกํฐ z๋ฅผ ๊ฐํ๋ ๋ฐฉ์์ด์๋๋ฐ,
์ฌ๊ธฐ์ ์ปจ๋์ ๋ฒกํฐ ๋์ ์ x๋ฅผ ์ ๋ ฅ ์์์ผ๋ก ๋ณธ๋ค. : ์์ ์
Pix2Pix๋ {๋ณํํ ์์ x, ๋ณํ ๊ฒฐ๊ณผ ์์ y} ๋ฅผ ์์ผ๋ก ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์, ์ ์๊ณผ ๊ฐ์ ๋ชฉํํจ์๊ฐ ๋์ถ๋จ.
"Context Encoder" ๋ ผ๋ฌธ์ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ์๋ค์ -> loss ํจ์๋ฅผ ์ค๊ณํ ๋ traditional loss((ex: L2 loss))๋ฅผ GAN loss์ ์์ด ์ฌ์ฉํ๋ ๊ฒ์ด ํจ๊ณผ์ ์ด๋ผ๋ ๊ฒ.
์คํ ๊ฒฐ๊ณผ L1 loss์ L2 loss ์ค L1์ด ์ด๋ฏธ์ง์ blurํจ์ด ๋ ํด L1 loss๋ฅผ ์ฌ์ฉํ๊ธฐ๋ก ๊ฒฐ์ .
L1+cCAN์ด ์ ค ํจ๊ณผ์ ์ธ ๊ฒ์ ํ์ธํ ์ ์๋ค.
Random noise Z ์ฌ์ฉ X?
G๊ฐ z๋ฅผ ๋ฌด์ํ๋๋ก ํ์ต์ด ๋๊ธฐ ๋๋ฌธ์, random noise Z๋ฅผ ์ฌ์ฉํ์ง ์๋๋ค.
Z๊ฐ ์๊ธฐ ๋๋ฌธ์ mapping ๊ณผ์ ์์ stochastic((๋ฌด์์))ํ ๊ฒฐ๊ณผ๋ฅผ ๋ด์ง ๋ชปํ๊ณ deterministicํ ๊ฒฐ๊ณผ๋ฅผ ๋ด๋ณด๋ด๋ ๋ฌธ์ ๊ฐ ์์ง๋ง, ๋ ผ๋ฌธ์์๋์ด๋ฅผ future work๋ก์จ ๋จ๊ฒจ๋์๋ค.
๋์ ์กฐ๊ธ์ ๋ฌด์์์ฑ์ด ์๊ธฐ๋๋ก layer์ dropout์ ์ ์ฉํ์๋ค.
3.2 Network architectures
3.2.1 Generator with skips
Generator : U-Net ๊ธฐ๋ฐ
encoder-decoder ๊ตฌ์กฐ์์ ์์ ํฌ๊ธฐ๋ฅผ ์ค์๋ค๊ฐ ๋ค์ ํค์ฐ๋ ๊ณผ์ ์์ detail์ด ์ฌ๋ผ์ง๋ฉด์ ์์์ด blurํด์ง๋ ๋ฌธ์ ๋ฅผ ํผํ๊ธฐ ์ํด, ์ค๋ฅธ์ชฝ์ฒ๋ผ ๋ค์ skip connection์ ๊ฐ์ง๋ค.
U-Net์ด๋ ์ด๋ฆ์ด ๋ถ์ ์ด์ ๋ ์์ ๊ฐ์ด U์ ํํ๋ก ํํํ ์ ์๊ธฐ ๋๋ฌธ.
๋ ์ด์ด๊ฐ ๊น์ด์ง ์๋ก ์ธ๋ฐํ ์ ๋ณด๋ ์ด๋์ ๊ฐ ๋ฐ์ง ์๋ ์ด์ ๋จ์ง ์์. (skip connection ์ญํ )
๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์!
3.2.2 Markovian discriminator (PatchGAN)
Disciminator: PatchGAN ์ฌ์ฉ
L1 Loss์ ํน์ง: Low-frequency ์ฑ๋ถ์ ์ ๊ฒ์ถ ํจ.
* Frequency : ์ด๋ฏธ์ง์์ ํฝ์ ๋ณํ์ ์ ๋
- low-frequency : ์ฌ๋ฌผ์ ๋ํด์๋ ์ฌ๋ฌผ ๋ด์์๋ ์ ๋ณํ๊ฐ ํฌ์ง ์์.
- high-frequency : ์ฌ๋ฌผ ๊ฒฝ๊ณ( : edge)์์๋ ์์ด ๊ธ๊ฒฉํ๊ฒ ๋ณํจ.
์ฆ, L1 loss๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ blurryํ์ง๋ง low-frequency ์ฑ๋ถ๋ค์ ์ ๊ฒ์ถํด๋ด๋ฏ๋ก ์ด๋ ๊ทธ๋๋ก ๋๊ณ Discriminator์์ high-frequency์ ๊ฒ์ถ์ ์ํ ๋ชจ๋ธ๋ง์ ์งํํ๊ฒ ๋๋ค.
=> ์ด๋ฏธ์ง ์ ์ฒด๊ฐ ํ์์๊ณ , local image patch ๋ฅผ ์ฌ์ฉํด high-frequency๋ฅผ ํ๋ณํด๋ ๊ด์ฐฎ์.
PatchGAN:
์ ์ฒด ์์ญ์ด ์๋๋ผ, ํน์ ํฌ๊ธฐ์ patch ๋จ์๋ก ์ง์ง/๊ฐ์ง๋ฅผ ํ๋ณํ๊ณ , ๊ทธ ๊ฒฐ๊ณผ์ ํ๊ท ์ ์ทจํ๋ ๋ฐฉ์.
patch ํฌ๊ธฐ์ ๋ฐ๋ฅธ ์คํ์ ์งํํจ.
70x70 ์ผ ๋, ๊ฐ์ฅ ํจ๊ณผ์ ์ธ ๊ฒฐ๊ณผ๊ฐ ๋์ค๋ ๊ฒ์ ํ์ธ.
=> ์ ์ฒด ์ด๋ฏธ์ง์ ๋ํ ํ๋ณ๋ณด๋ค, ์ ์ ํ ํฌ๊ธฐ์ patch๋ฅผ ์ ํ๊ณ ๊ทธ patch๋ค์ด ๋๋ถ๋ถ ์ง์ง์ธ ์ชฝ์ผ๋ก ๊ฒฐ๊ณผ๊ฐ ๋์ค๋๋ก ํ์ต์ ์งํํ๋ ๊ฒ์ด ํจ์จ์ ์ด๋ค.
3.3 Optimization and inference
ํน์ด์ฌํญ ๋ช ๊ฐ์ง.
- train:
1. G๋ฅผ ํ์ตํ ๋, log(x,G(x,z)) ๋ฅผ ์ต๋ํ ํ์ตํ๋๋ก ํจ.
2. D๋ฅผ ํ์ตํ ๋, loss๋ฅผ 2๋ก ๋๋ ์ ํ์ต. D๋ฅผ G๋ณด๋ค ํ์ต ์๋๋ฅผ ๋๋ฆฌ๊ฒํ๊ธฐ ์ํด์์ด๋ค.
- test:
1. dropout์ ์ด๋ค.
2. batch normalization์ test batch์ statistics๋ฅผ ์ฌ์ฉํจ.
4. Experiments
4.1 Evaluation metrics
๊ฒฐ๊ณผ๋ฅผ ์ ์ฒด์ ์ผ๋ก ํ๊ฐํ๊ธฐ ์ํด 2๊ฐ์ง ์ ๋ต์ ์ฌ์ฉ.
1. Amazon Mechanical Turk(AMT)๋ก ‘real vs fake’์ ๋ํด ์ฌ๋์ด ํ๊ฐํ๋ ์ธ์ํ ์คํธ.
์ฌ๋์ด ์ฃผ์ด์ง ์ด๋ฏธ์ง๋ฅผ ๋ณด๊ณ ์ง์ง์ธ์ง ๊ฐ์ง์ธ์ง ํ๋จํ๋ ์คํ, ์์ฑ๋ ์ด๋ฏธ์ง๊ฐ ์ผ๋ง๋ ๋ง์ ์ฌ๋๋ค์๊ฒ ์ง์ง ์ด๋ฏธ์ง์ฒ๋ผ ๋ณด์ฌ ์ฌ๋๋ค์ ์์ผ ์ ์์๋์ง์ ๋ํ ํผ์ผํธ๋ฅผ ํตํด ๋ถ์.
2. ‘FCN-score’๋ก ๊ธฐ์กด์ ๋ถ๋ฅ ๋ชจ๋ธ์ ํฉ์ฑ๋ ์ด๋ฏธ์ง ์์ ๋ฌผ์ฒด๋ฅผ ์ธ์ํ ์ ์์ ์ ๋๋ก ํฉ์ฑ๋ ์ด๋ฏธ์ง๊ฐ ํ์ค์ ์ธ์ง ์ฌ๋ถ๋ฅผ ์ธก์ ํ๋ ๋ฐฉ๋ฒ.
๊ธฐ์กด์ semantic segmentation ๋ชจ๋ธ์ ์ฌ์ฉํด ์์ฑ ์ด๋ฏธ์ง ๋ด์ object๋ค์ ์ผ๋ง๋ ์ ํํ๊ฒ ํด๋์ค ๋ณ๋ก ํฝ์ ์ ๋ํ๋ด๋ ์ง๋ฅผ ์คํ,์์ฑ๋ ์ด๋ฏธ์ง๊ณผ ํ์ค์ ์ด๋ฏธ์ง์ ์ ์ฌํ ์๋ก segmentation ๋ชจ๋ธ ๋ํ ๋ ์ ํํ๊ฒ ์ด๋ฏธ์ง ๋ด์ object ๋ค์ segmentation ํ ์ ์์ ๊ฒ์ด๋ผ๋ ์์ด๋์ด๋ก segmentation ๊ฒฐ๊ณผ์ pixel, class ๋ณ ์ ํ๋ ๋ฑ์ ๋ถ์.
4.4 From PixelGANs to PatchGANs to ImageGANs
์ง๋ -> ํญ๊ณต์ฌ์ง / ํญ๊ณต์ฌ์ง -> ์ง๋
์์์ ๋ํ label ๋ฌ์์ฃผ๋ semantic segmentation ์ํ
=> L1์ด ์ ํ๋๊ฐ ๊ฐ์ฅ ๋์ ๋ฐฉ๋ฒ
5. Conclusion
๋ณธ ๋ ผ๋ฌธ์ ๊ฒฐ๊ณผ๋ ์กฐ๊ฑด๋ถ adversarial network๊ฐ ๋ง์ image-to-image ๋ณํ ์์ ,
ํนํ ๊ตฌ์กฐํ๋ ๊ทธ๋ํฝ ์ถ๋ ฅ์ ํฌํจํ๋ ์์ ์ ์ ๋งํ ์ ๊ทผ ๋ฐฉ์์.
์ ๋ ฅ์ด ๋น์ ์์ ์ด๊ฑฐ๋ ์ ๋ ฅ ์ด๋ฏธ์ง ๋ด์ ์ผ๋ถ ์์ญ์ด ๋ง์ด ๋น์ด์์ด ์ด๋ค ์ ๋ณด๋ฅผ ๋ํ๋ด๋ ์ง ๋ชจ๋ธ์ด ์๊ธฐ ์ด๋ ค์ธ ๋ ์คํจํ๋ ๊ฒฝ์ฐ๊ฐ ๋ง๋ค.
6. Appendix
๋คํธ์ํฌ์ ๊ตฌ์กฐ์ ํ์ต์ ์ฌ์ฉํ ํ๋ผ๋ฏธํฐ๋ค ๋์์์.
๊ตฌํ
Model
Network Implementation
1. CBR2D : Convolution-BatchNormalization-ReLU
2. DECBR2D : ConvolutionTranspose-BatchNormalization-ReLU (decoder๋ฅผ ์ํ)
Generator
Encoder)
C64-C128-C256-C512-C512-C512-C512-C512
Decoder)
CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
(์ด ๊ฒฝ์ฐ U-Net์ ํํ๋ก i๋ฒ์งธ layer ์งํฉ์ฒด์ 8-i๋ฒ์งธ layer ์งํฉ์ฒด๊ฐ์ skip connection์ด ์กด์ฌ)
class Pix2Pix(nn.Module):
def __init__(self, in_channels=3, out_channels, nker=64, norm="bnorm"):
super(Pix2Pix, self).__init__()
self.enc1 = CBR2d(in_channels, 1 * nker, kernel_size=4, padding=1,
norm=None, relu=0.2, stride=2)
self.enc2 = CBR2d(1 * nker, 2 * nker, kernel_size=4, padding=1,
norm=norm, relu=0.2, stride=2)
self.enc3 = CBR2d(2 * nker, 4 * nker, kernel_size=4, padding=1,
norm=norm, relu=0.2, stride=2)
self.enc4 = CBR2d(4 * nker, 8 * nker, kernel_size=4, padding=1,
norm=norm, relu=0.2, stride=2)
self.enc5 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
norm=norm, relu=0.2, stride=2)
self.enc6 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
norm=norm, relu=0.2, stride=2)
self.enc7 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
norm=norm, relu=0.2, stride=2)
self.enc8 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
norm=norm, relu=0.2, stride=2)
self.dec1 = DECBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
norm=norm, relu=0.0, stride=2)
self.drop1 = nn.Dropout2d(0.5)
self.dec2 = DECBR2d(2 * 8 * nker, 8 * nker, kernel_size=4, padding=1,
norm=norm, relu=0.0, stride=2)
self.drop2 = nn.Dropout2d(0.5)
self.dec3 = DECBR2d(2 * 8 * nker, 8 * nker, kernel_size=4, padding=1,
norm=norm, relu=0.0, stride=2)
self.drop3 = nn.Dropout2d(0.5)
self.dec4 = DECBR2d(2 * 8 * nker, 8 * nker, kernel_size=4, padding=1,
norm=norm, relu=0.0, stride=2)
self.dec5 = DECBR2d(2 * 8 * nker, 4 * nker, kernel_size=4, padding=1,
norm=norm, relu=0.0, stride=2)
self.dec6 = DECBR2d(2 * 4 * nker, 2 * nker, kernel_size=4, padding=1,
norm=norm, relu=0.0, stride=2)
self.dec7 = DECBR2d(2 * 2 * nker, 1 * nker, kernel_size=4, padding=1,
norm=norm, relu=0.0, stride=2)
self.dec8 = DECBR2d(2 * 1 * nker, out_channels, kernel_size=4, padding=1,
norm=None, relu=None, stride=2)
def forward(self, x):
enc1 = self.enc1(x)
enc2 = self.enc2(enc1)
enc3 = self.enc3(enc2)
enc4 = self.enc4(enc3)
enc5 = self.enc5(enc4)
enc6 = self.enc6(enc5)
enc7 = self.enc7(enc6)
enc8 = self.enc8(enc7)
dec1 = self.dec1(enc8)
drop1 = self.drop1(dec1)
cat2 = torch.cat((drop1, enc7), dim=1)
dec2 = self.dec2(cat2)
drop2 = self.drop2(dec2)
cat3 = torch.cat((drop2, enc6), dim=1)
dec3 = self.dec3(cat3)
drop3 = self.drop3(dec3)
cat4 = torch.cat((drop3, enc5), dim=1)
dec4 = self.dec4(cat4)
cat5 = torch.cat((dec4, enc4), dim=1)
dec5 = self.dec5(cat5)
cat6 = torch.cat((dec5, enc3), dim=1)
dec6 = self.dec6(cat6)
cat7 = torch.cat((dec6, enc2), dim=1)
dec7 = self.dec7(cat7)
cat8 = torch.cat((dec7, enc1), dim=1)
dec8 = self.dec8(cat8)
x = torch.tanh(dec8)
return x
- input_channel ์๊ฐ 3์ด๊ธฐ ๋๋ฌธ์ default๋ฅผ 3์ผ๋ก ์งํ
- pix2pix์์ encoder์ ์ฒซ ํํ๊ฐ C64์ด๊ณ 2๋ฐฐ์ฉ ๋์ด๋๋ ํํ์ด๊ธฐ ๋๋ฌธ์ default nker๋ฅผ 64๋ก ์ค์ ํ๊ณ 2๋ฐฐ์ฉ ๋๋ ค์ฃผ๋ ํํ๋ก layer๋ค์ ๊ตฌํ
- size๋ฅผ 1/2์ฉ์ผ๋ก ์ค์ด๊ธฐ ์ํด padding์ผ๋ก 1์ ์ฌ์ฉ
- decoder์ 3๋ฒ์งธ layer ์งํฉ์ฒด๊น์ง dropout(0.5)
- forward()์์๋ ์์ ์ ์ํ encoder์ decoder๋ฅผ ์์ฐจ์ ์ผ๋ก ์งํ
: skip connection์ ๊ตฌํํ๊ธฐ ์ํด์ torch.cat์ ์ด์ฉ -> i ๋ฒ์งธ์ 8-i ๋ฒ์งธ์ layer ์งํฉ์ฒด์ ๊ฒฐ๊ณผ๋ฅผ concatenate
Discriminator
C64-C128-C256-C512์ ๊ตฌ์กฐ
class Discriminator(nn.Module):
def __init__(self, in_channels, out_channels, nker=64, norm="bnorm"):
super(Discriminator, self).__init__()
self.enc1 = CBR2d(1 * in_channels, 1 * nker, kernel_size=4, stride=2,
padding=1, norm=None, relu=0.2, bias=False)
self.enc2 = CBR2d(1 * nker, 2 * nker, kernel_size=4, stride=2,
padding=1, norm=norm, relu=0.2, bias=False)
self.enc3 = CBR2d(2 * nker, 4 * nker, kernel_size=4, stride=2,
padding=1, norm=norm, relu=0.2, bias=False)
self.enc4 = CBR2d(4 * nker, out_channels, kernel_size=4, stride=2,
padding=1, norm=norm, relu=0.2, bias=False)
def forward(self, x):
x = self.enc1(x)
x = self.enc2(x)
x = self.enc3(x)
x = self.enc4(x)
x = torch.sigmoid(x)
return x
- ๋ง์ง๋ง layer์ ๋์ sigmoid๊ฐ ๋ฑ์ฅ
- leakyReLU 0.2
Train
# train ์งํํ๊ธฐ
for epoch in range(st_epoch + 1, num_epoch + 1):
netG.train()
netD.train()
generatorL1Loss = []
generatorGanLoss = []
discriminatorRealLoss = []
discriminatorFakeLoss = []
for batch, data in enumerate(dataLoader, 1):
# forward netG
label = data['label'].to(device)
input = data['input'].to(device)
output = netG(input)
# backward netD
set_requires_grad(netD, True)
optimD.zero_grad()
real = torch.cat([input, label], dim=1)
fake = torch.cat([input, output], dim=1)
predictReal = netD(real)
predictFake = netD(fake.detach())
realLossD = ganLoss(predictReal, torch.ones_like(predictReal))
fakeLossD = ganLoss(predictFake, torch.zeros_like(predictFake))
discriminatorLoss = 0.5 * (realLossD + fakeLossD)
discriminatorLoss.backward()
optimD.step()
# backward netG
set_requires_grad(netD, False)
optimG.zero_grad()
fake = torch.cat([input, output], dim=1)
predictFake = netD(fake)
ganLossG = ganLoss(predictFake, torch.ones_like(predictFake))
l1LossG = l1Loss(output, label)
generatorLoss = ganLossG + wgt * l1LossG
generatorLoss.backward()
optimG.step()
# loss ๊ณ์ฐ
generatorL1Loss += [l1LossG.item()]
generatorGanLoss += [ganLossG.item()]
discriminatorRealLoss += [realLossD.item()]
discriminatorFakeLoss += [fakeLossD.item()]
# ๋งค batch ๋ง๋ค์ loss ์ถ๋ ฅ
print("TRAIN: EPOCH %04d / %04d | BATCH %04d / %04d | "
"GENERATOR L1 LOSS %.4f | GENERATOR GAN LOSS %.4f | "
"DISCRIMINATOR REAL LOSS: %.4f | DISCRIMINATOR FAKE LOSS: %.4f" %
(epoch, num_epoch, batch, num_batch_train,
np.mean(loss_G_l1_train), np.mean(loss_G_gan_train),
np.mean(loss_D_real_train), np.mean(loss_D_fake_train)))
- mini-batch๋ฅผ ์ด์ฉ
- generator๋ฅผ ์ด์ฉํด ์ด๋ฏธ์ง๋ฅผ ์์ฑํด ๋ด๊ณ , netD๋ ๊ทธ ์ด๋ฏธ์ง๋ฅผ ์ด์ฉํด ์ง์์ฌ๋ถ๋ฅผ ์ํํ ๋ค back propagation์ ์งํํ์ฌ gradient descent ๋ฐฉํฅ์ผ๋ก ํ์ตํ๊ณ , ์ดํ์ netG๊ฐ back propagation์ ์งํํ์ฌ gradient descent ๋ฐฉํฅ์ผ๋ก ํ์ต์ด ์งํ
- backward netD์์๋ง fake.detach()๋ฅผ ํด์ค ์ด์ ๋ discriminator๋ฅผ back propagtion ํ ๋ ์ํ์ง ์๋ generator์ back propagation์ ๋ง๊ธฐ ์ํจ
๋ ํผ๋ฐ์ค
: https://medium.com/humanscape-tech/ml-practice-pix2pix-1-d89d1a011dbf
[ML Practice] pix2pix(1)
ํด๋จผ์ค์ผ์ดํ Software engineer Covy์ ๋๋ค.
medium.com
'AI > ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ Paper Review' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
NeRF | ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ (0) | 2024.05.25 |
---|---|
DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation (0) | 2024.05.21 |
Wasserstein GAN : arXiv 2017 | ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ (0) | 2024.03.03 |
DCGAN : ICLR 2016 (0) | 2024.01.28 |
Generative Adversarial Nets : arXive 2014 (0) | 2024.01.15 |
- Total
- Today
- Yesterday
- C์ธ์ด
- ๋ ผ๋ฌธ๋ฆฌ๋ทฐ
- Aimers
- AI์ปจํผ๋ฐ์ค
- ํ์ด์ฌ
- AIRUSH2023
- SKTECHSUMMIT
- SQL
- Paper review
- gs๋ ผ๋ฌธ
- ์ฝํ ์ค๋น
- gan
- ์คํ ์ด๋ธ๋ํจ์
- ํ๋ก๊ทธ๋๋จธ์ค
- ์ฝ๋ฉ์๋ฌ
- ์ปดํจํฐ๋น์
- ๋ ผ๋ฌธ
- dreambooth
- 2d-gs
- ์ฝ๋ฉ๊ณต๋ถ
- CLOVAX
- AIRUSH
- ๋ ผ๋ฌธ์ฝ๊ธฐ
- ํ์ด์ฌ์ฝํ
- lgaimers
- MYSQL
- Gaussian Splatting
- ๋๋ฆผ๋ถ์ค
- ํ ํฌ์๋ฐ
- 3d-gs
์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | ||
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 |