

Pix2Pix : CVPR 2017

ํ•ด๋“œ์œ„๊ทธ 2024. 2. 7. 19:48

Phillip Isola et al.

 

1. Introduction

์„œ๋กœ ์—ฐ๊ด€๋œ pair dataset ํ•„์š”

 

Image-to-Image Translation์€ ์ด๋ฏธ์ง€๋ฅผ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›์•„์„œ ๋˜ ๋‹ค๋ฅธ ์ด๋ฏธ์ง€๋ฅผ ์ถœ๋ ฅ์œผ๋กœ ๋ฐ˜ํ™˜ํ•˜๋Š” Task๋ฅผ ๋œปํ•œ๋‹ค.

 

๋ณธ ๋…ผ๋ฌธ์€ Image-to-Image Translation์— ์ ํ•ฉํ•œ cGAN์„ ๊ธฐ๋ฐ˜์œผ๋กœํ•˜๋ฉฐ ๋‹ค์–‘ํ•œ Task์—์„œ ์ข‹์€ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ด๋Š” ํ”„๋ ˆ์ž„์›Œํฌ pix2pix๋ฅผ ๋‹ค๋ฃฌ๋‹ค. ์ž๋™ ์–ธ์–ด ๋ฒˆ์—ญ์ด ๊ฐ€๋Šฅํ•œ ๊ฒƒ์ฒ˜๋Ÿผ ์ž๋™ image-to-image ๋ณ€ํ™˜ ๋˜ํ•œ ์ถฉ๋ถ„ํ•œ ํ•™์Šต ๋ฐ์ดํ„ฐ๊ฐ€ ์ฃผ์–ด์ง„๋‹ค๋ฉด ํ•œ ์žฅ๋ฉด์˜ ํ‘œํ˜„์„ ๋‹ค๋ฅธ ์žฅ๋ฉด์œผ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ์ž‘์—…์œผ๋กœ ์ •์˜ํ•  ์ˆ˜ ์žˆ๋‹ค.

The difference from DCGAN is that the Generator (G)'s input is not a random vector but a condition input.

๋ณธ ๋…ผ๋ฌธ์—์„œ ์šฐ๋ฆฌ์˜ ๋ชฉํ‘œ๋Š” ์ด๋Ÿฌํ•œ ๋ชจ๋“  ๋ฌธ์ œ์— ๋Œ€ํ•œ ๊ณตํ†ต๋œ ํ”„๋ ˆ์ž„์›Œํฌ๋ฅผ ๊ฐœ๋ฐœํ•˜๋Š” ๊ฒƒ์ด๋‹ค. -> ๋‹ค์–‘ํ•œ ๊ณผ์ œ ํ•ด๊ฒฐ๊ฐ€๋Šฅ!

In CNNs used for image prediction, losses based on Euclidean distance such as L1 and L2 produce blurry results, but combining them with a GAN loss makes it possible to generate sharp images.

 

2. Related work

Covers structured losses for image modeling and conditional GANs.

* ๋ณธ ๋…ผ๋ฌธ์˜ ์—ฐ๊ตฌ๋Š” ํŠน์ • ๋ชฉ์ ์ด ์—†๋Š” ๊ฒƒ์ด ํŠน์ง•์ž„.

 

3. Method

cGAN

GANs are models that learn a mapping G : z → y from a random noise vector z to an output image y. In contrast, conditional GANs learn a mapping G : {x, z} → y from a conditioning image x and a random noise vector z to an output image y.

์œ„ ๊ทธ๋ฆผ์—์„œ๋Š” (edge)๋ฅผ ์กฐ๊ฑด์œผ๋กœ ๋ฐ›์•„ ์‹ค์ œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๊ณ , ํŒ๋ณ„์ž ๋˜ํ•œ ์ด (edge)์™€ ์ƒ์„ฑ ์ด๋ฏธ์ง€๋ฅผ ๋ฐ›์•„ ํŒ๋ณ„ํ•˜๋ฉด์„œ ๋ชจ๋ธ์„ ๋ฐœ์ „์‹œํ‚ค๊ฒŒ ๋œ๋‹ค.

 

3.1 Objective

GAN์˜ objective loss๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.

 

Conditional GAN์€ ๊ธฐ์กด GAN์—์„œ condition x๋งŒ ์ถ”๊ฐ€๋œ ๊ฒƒ.

์›๋ž˜ cGAN์—์„œ๋Š” latent vector z์— ์ปจ๋””์…˜ ๋ฒกํ„ฐ z๋ฅผ ๊ฐ€ํ•˜๋Š” ๋ฐฉ์‹์ด์—ˆ๋Š”๋ฐ,

์—ฌ๊ธฐ์„œ ์ปจ๋””์…˜๋ฒกํ„ฐ ๋Œ€์‹ ์— x๋ฅผ ์ž…๋ ฅ ์˜์ƒ์œผ๋กœ ๋ณธ๋‹ค. : ์œ„์˜ ์‹

Pix2Pix๋Š” {๋ณ€ํ™˜ํ•  ์˜์ƒ x, ๋ณ€ํ™˜ ๊ฒฐ๊ณผ ์˜์ƒ y} ๋ฅผ ์Œ์œผ๋กœ ์‚ฌ์šฉํ•˜๊ธฐ ๋•Œ๋ฌธ์—, ์œ„ ์‹๊ณผ ๊ฐ™์€ ๋ชฉํ‘œํ•จ์ˆ˜๊ฐ€ ๋„์ถœ๋จ.

 

"Context Encoder" ๋…ผ๋ฌธ์˜ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ›์•„๋“ค์ž„ -> loss ํ•จ์ˆ˜๋ฅผ ์„ค๊ณ„ํ•  ๋•Œ traditional loss((ex: L2 loss))๋ฅผ GAN loss์™€ ์„ž์–ด ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ํšจ๊ณผ์ ์ด๋ผ๋Š” ๊ฒƒ.

Experiments showed that, between L1 and L2 loss, L1 blurs the image less, so L1 loss was chosen.

L1 loss:

$\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}[\lVert y - G(x, z) \rVert_1]$

Final objective:

$G^* = \arg\min_G \max_D \mathcal{L}_{cGAN}(G, D) + \lambda\, \mathcal{L}_{L1}(G)$

(Figure: results for each loss.)

L1+cGAN can be seen to be the most effective.

 

Why is random noise z not used?

Because G simply learns to ignore z during training, random noise z is not used.

Without z, the mapping cannot produce stochastic (random) outputs and only produces deterministic ones; the paper leaves this problem as future work.

๋Œ€์‹  ์กฐ๊ธˆ์˜ ๋ฌด์ž‘์œ„์„ฑ์ด ์ƒ๊ธฐ๋„๋ก layer์— dropout์„ ์ ์šฉํ•˜์˜€๋‹ค.

 

3.2 Network architectures

3.2.1 Generator with skips

Generator: based on U-Net

In a plain encoder-decoder structure, detail is lost while the image is shrunk and then enlarged again, making the output blurry; to avoid this, skip connections are added, as shown on the right.

Symmetric structure

U-Net์ด๋ž€ ์ด๋ฆ„์ด ๋ถ™์€ ์ด์œ ๋Š” ์œ„์™€ ๊ฐ™์ด U์ž ํ˜•ํƒœ๋กœ ํ‘œํ˜„ํ•  ์ˆ˜ ์žˆ๊ธฐ ๋•Œ๋ฌธ.

๋ ˆ์ด์–ด๊ฐ€ ๊นŠ์–ด์งˆ ์ˆ˜๋ก ์„ธ๋ฐ€ํ•œ ์ •๋ณด๋Š” ์–ด๋””์„ ๊ฐ€ ๋ฐ›์ง€ ์•Š๋Š” ์ด์ƒ ๋‚จ์ง€ ์•Š์Œ. (skip connection ์—ญํ• )

์ธ์ฝ”๋”-๋””์ฝ”๋” VS U-Net

๋” ์ข‹์€ ์„ฑ๋Šฅ์„ ๋ณด์ž„!

 

3.2.2 Markovian discriminator (PatchGAN)

Discriminator: uses a PatchGAN

L1 Loss์˜ ํŠน์ง•: Low-frequency ์„ฑ๋ถ„์„ ์ž˜ ๊ฒ€์ถœ ํ•จ.

* Frequency: the degree of pixel variation in an image.
- low-frequency: within an object, the color changes little.
- high-frequency: at object boundaries (edges), the color changes abruptly.

 

That is, L1 loss gives blurry results but captures the low-frequency components well, so it is kept as-is, and the discriminator is modeled to capture the high-frequency components.

=> The whole image is not needed; it is fine to judge high frequencies on local image patches.

 

PatchGAN:

์ „์ฒด ์˜์—ญ์ด ์•„๋‹ˆ๋ผ, ํŠน์ • ํฌ๊ธฐ์˜ patch ๋‹จ์œ„๋กœ ์ง„์งœ/๊ฐ€์งœ๋ฅผ ํŒ๋ณ„ํ•˜๊ณ , ๊ทธ ๊ฒฐ๊ณผ์— ํ‰๊ท ์„ ์ทจํ•˜๋Š” ๋ฐฉ์‹.

 

patch ํฌ๊ธฐ์— ๋”ฐ๋ฅธ ์‹คํ—˜์„ ์ง„ํ–‰ํ•จ.

70x70 patches were confirmed to give the most effective results.

=> Rather than judging the entire image, it is more efficient to pick an appropriate patch size and train so that most patches come out on the 'real' side.
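
For reference, "70x70" is the receptive field of one output neuron of the discriminator. A small sketch (my own, under the layer configuration from the paper's appendix: kernel 4 everywhere, strides 2, 2, 2, 1, plus a stride-1 output convolution) of how that number falls out:

def receptive_field(layers):
    # walk backwards from one output neuron to the input pixels it can see;
    # layers is a list of (kernel_size, stride) tuples, first layer first
    r = 1
    for k, s in reversed(layers):
        r = r * s + (k - s)
    return r

# 70x70 PatchGAN: C64-C128-C256-C512 plus a 1-channel output convolution
print(receptive_field([(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]))  # -> 70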

 

3.3 Optimization and inference

ํŠน์ด์‚ฌํ•ญ ๋ช‡ ๊ฐ€์ง€.

- train:

1. When training G, it is trained to maximize log D(x, G(x, z)) (rather than to minimize log(1 - D(x, G(x, z)))).

2. When training D, the loss is divided by 2, to slow D's learning down relative to G.

- test:

1. Dropout stays on.

2. Batch normalization uses the statistics of the test batch (not accumulated training statistics).
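
In PyTorch terms (a sketch, assuming a generator netG and an input test_input; this is not code from the post), both test-time behaviors fall out of simply keeping the generator in train mode, since nn.Dropout stays active and nn.BatchNorm2d uses the current batch's statistics in that mode:

netG.train()           # keep Dropout active; BatchNorm uses the test batch's statistics
with torch.no_grad():  # gradients are still unnecessary at inference
    output = netG(test_input)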

 

4. Experiments

4.1 Evaluation metrics

Two strategies are used to evaluate the results as a whole.

 

1. A perceptual test on Amazon Mechanical Turk (AMT), where humans judge ‘real vs fake’.

์‚ฌ๋žŒ์ด ์ฃผ์–ด์ง„ ์ด๋ฏธ์ง€๋ฅผ ๋ณด๊ณ  ์ง„์งœ์ธ์ง€ ๊ฐ€์งœ์ธ์ง€ ํŒ๋‹จํ•˜๋Š” ์‹คํ—˜, ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€๊ฐ€ ์–ผ๋งˆ๋‚˜ ๋งŽ์€ ์‚ฌ๋žŒ๋“ค์—๊ฒŒ ์ง„์งœ ์ด๋ฏธ์ง€์ฒ˜๋Ÿผ ๋ณด์—ฌ ์‚ฌ๋žŒ๋“ค์„ ์†์ผ ์ˆ˜ ์žˆ์—ˆ๋Š”์ง€์— ๋Œ€ํ•œ ํผ์„ผํŠธ๋ฅผ ํ†ตํ•ด ๋ถ„์„.

 

2. The ‘FCN-score’, which measures whether a synthesized image is realistic enough that an off-the-shelf recognition model can identify the objects inside it.

๊ธฐ์กด์˜ semantic segmentation ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•ด ์ƒ์„ฑ ์ด๋ฏธ์ง€ ๋‚ด์˜ object๋“ค์„ ์–ผ๋งˆ๋‚˜ ์ •ํ™•ํ•˜๊ฒŒ ํด๋ž˜์Šค ๋ณ„๋กœ ํ”ฝ์…€์— ๋‚˜ํƒ€๋‚ด๋Š” ์ง€๋ฅผ ์‹คํ—˜,์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€๊ณผ ํ˜„์‹ค์˜ ์ด๋ฏธ์ง€์™€ ์œ ์‚ฌํ•  ์ˆ˜๋ก segmentation ๋ชจ๋ธ ๋˜ํ•œ ๋” ์ •ํ™•ํ•˜๊ฒŒ ์ด๋ฏธ์ง€ ๋‚ด์˜ object ๋“ค์„ segmentation ํ•  ์ˆ˜ ์žˆ์„ ๊ฒƒ์ด๋ผ๋Š” ์•„์ด๋””์–ด๋กœ segmentation ๊ฒฐ๊ณผ์˜ pixel, class ๋ณ„ ์ •ํ™•๋„ ๋“ฑ์„ ๋ถ„์„.

4.4 From PixelGANs to PatchGANs to ImageGANs

 

์ง€๋„ -> ํ•ญ๊ณต์‚ฌ์ง„ / ํ•ญ๊ณต์‚ฌ์ง„ -> ์ง€๋„

 

์˜์ƒ์— ๋Œ€ํ•œ label ๋‹ฌ์•„์ฃผ๋Š” semantic segmentation ์ˆ˜ํ–‰

=> L1์ด ์ •ํ™•๋„๊ฐ€ ๊ฐ€์žฅ ๋†’์€ ๋ฐฉ๋ฒ•

 

5. Conclusion

๋ณธ ๋…ผ๋ฌธ์˜ ๊ฒฐ๊ณผ๋Š” ์กฐ๊ฑด๋ถ€ adversarial network๊ฐ€ ๋งŽ์€ image-to-image ๋ณ€ํ™˜ ์ž‘์—…,

ํŠนํžˆ ๊ตฌ์กฐํ™”๋œ ๊ทธ๋ž˜ํ”ฝ ์ถœ๋ ฅ์„ ํฌํ•จํ•˜๋Š” ์ž‘์—…์— ์œ ๋งํ•œ ์ ‘๊ทผ ๋ฐฉ์‹์ž„.

Failure cases

์ž…๋ ฅ์ด ๋น„์ •์ƒ์ ์ด๊ฑฐ๋‚˜ ์ž…๋ ฅ ์ด๋ฏธ์ง€ ๋‚ด์˜ ์ผ๋ถ€ ์˜์—ญ์ด ๋งŽ์ด ๋น„์–ด์žˆ์–ด ์–ด๋–ค ์ •๋ณด๋ฅผ ๋‚˜ํƒ€๋‚ด๋Š” ์ง€ ๋ชจ๋ธ์ด ์•Œ๊ธฐ ์–ด๋ ค์šธ ๋•Œ ์‹คํŒจํ•˜๋Š” ๊ฒฝ์šฐ๊ฐ€ ๋งŽ๋‹ค.

 

6. Appendix

๋„คํŠธ์›Œํฌ์˜ ๊ตฌ์กฐ์™€ ํ•™์Šต์— ์‚ฌ์šฉํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ๋“ค ๋‚˜์™€์žˆ์Œ.

 

Implementation

Model

Network Implementation

1. CBR2d : Convolution-BatchNormalization-ReLU

2. DECBR2d : ConvolutionTranspose-BatchNormalization-ReLU (for the decoder)
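
The post never defines these two helpers. A minimal sketch consistent with how the code below calls them (keyword arguments norm, relu, stride, padding, bias; relu=0.0 means plain ReLU, a positive value means LeakyReLU with that slope, and None means no activation) might look like this:

import torch
import torch.nn as nn

def CBR2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1,
          norm="bnorm", relu=0.0, bias=True):
    # Convolution -> (BatchNorm) -> (ReLU / LeakyReLU)
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                        stride=stride, padding=padding, bias=bias)]
    if norm == "bnorm":
        layers.append(nn.BatchNorm2d(out_channels))
    if relu is not None:
        layers.append(nn.ReLU() if relu == 0.0 else nn.LeakyReLU(relu))
    return nn.Sequential(*layers)

def DECBR2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1,
            norm="bnorm", relu=0.0, bias=True):
    # Transposed convolution -> (BatchNorm) -> (ReLU / LeakyReLU)
    layers = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                                 stride=stride, padding=padding, bias=bias)]
    if norm == "bnorm":
        layers.append(nn.BatchNorm2d(out_channels))
    if relu is not None:
        layers.append(nn.ReLU() if relu == 0.0 else nn.LeakyReLU(relu))
    return nn.Sequential(*layers)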

Generator

Encoder)
C64-C128-C256-C512-C512-C512-C512-C512

Decoder)
CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
(as a U-Net, skip connections exist between the i-th and the (8-i)-th layer blocks)

class Pix2Pix(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, nker=64, norm="bnorm"):
        super(Pix2Pix, self).__init__()

        self.enc1 = CBR2d(in_channels, 1 * nker, kernel_size=4, padding=1,
                          norm=None, relu=0.2, stride=2)

        self.enc2 = CBR2d(1 * nker, 2 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)

        self.enc3 = CBR2d(2 * nker, 4 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)

        self.enc4 = CBR2d(4 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)

        self.enc5 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)

        self.enc6 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)

        self.enc7 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)

        self.enc8 = CBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
                          norm=norm, relu=0.2, stride=2)

        self.dec1 = DECBR2d(8 * nker, 8 * nker, kernel_size=4, padding=1,
                            norm=norm, relu=0.0, stride=2)
        self.drop1 = nn.Dropout2d(0.5)

        self.dec2 = DECBR2d(2 * 8 * nker, 8 * nker, kernel_size=4, padding=1,
                            norm=norm, relu=0.0, stride=2)
        self.drop2 = nn.Dropout2d(0.5)

        self.dec3 = DECBR2d(2 * 8 * nker, 8 * nker, kernel_size=4, padding=1,
                            norm=norm, relu=0.0, stride=2)
        self.drop3 = nn.Dropout2d(0.5)

        self.dec4 = DECBR2d(2 * 8 * nker, 8 * nker, kernel_size=4, padding=1,
                            norm=norm, relu=0.0, stride=2)

        self.dec5 = DECBR2d(2 * 8 * nker, 4 * nker, kernel_size=4, padding=1,
                            norm=norm, relu=0.0, stride=2)

        self.dec6 = DECBR2d(2 * 4 * nker, 2 * nker, kernel_size=4, padding=1,
                            norm=norm, relu=0.0, stride=2)

        self.dec7 = DECBR2d(2 * 2 * nker, 1 * nker, kernel_size=4, padding=1,
                            norm=norm, relu=0.0, stride=2)

        self.dec8 = DECBR2d(2 * 1 * nker, out_channels, kernel_size=4, padding=1,
                            norm=None, relu=None, stride=2)

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.enc5(enc4)
        enc6 = self.enc6(enc5)
        enc7 = self.enc7(enc6)
        enc8 = self.enc8(enc7)

        dec1 = self.dec1(enc8)
        drop1 = self.drop1(dec1)

        cat2 = torch.cat((drop1, enc7), dim=1)
        dec2 = self.dec2(cat2)
        drop2 = self.drop2(dec2)

        cat3 = torch.cat((drop2, enc6), dim=1)
        dec3 = self.dec3(cat3)
        drop3 = self.drop3(dec3)

        cat4 = torch.cat((drop3, enc5), dim=1)
        dec4 = self.dec4(cat4)

        cat5 = torch.cat((dec4, enc4), dim=1)
        dec5 = self.dec5(cat5)

        cat6 = torch.cat((dec5, enc3), dim=1)
        dec6 = self.dec6(cat6)

        cat7 = torch.cat((dec6, enc2), dim=1)
        dec7 = self.dec7(cat7)

        cat8 = torch.cat((dec7, enc1), dim=1)
        dec8 = self.dec8(cat8)

        x = torch.tanh(dec8)

        return x

- The default for in_channels is 3 because the input has 3 (RGB) channels.

- In pix2pix the encoder starts at C64 and doubles the channel count each block, so the default nker is 64 and the layers double it step by step.

- padding of 1 (with kernel 4 and stride 2) is used so that each layer halves the spatial size.

- decoder์˜ 3๋ฒˆ์งธ layer ์ง‘ํ•ฉ์ฒด๊นŒ์ง€ dropout(0.5)

- forward() runs the encoder and decoder blocks defined above in order.

: torch.cat is used to implement the skip connections -> the outputs of the i-th and (8-i)-th layer blocks are concatenated.

 

Discriminator

C64-C128-C256-C512 structure

class Discriminator(nn.Module):
    def __init__(self, in_channels, out_channels, nker=64, norm="bnorm"):
        super(Discriminator, self).__init__()

        self.enc1 = CBR2d(1 * in_channels, 1 * nker, kernel_size=4, stride=2,
                          padding=1, norm=None, relu=0.2, bias=False)

        self.enc2 = CBR2d(1 * nker, 2 * nker, kernel_size=4, stride=2,
                          padding=1, norm=norm, relu=0.2, bias=False)

        self.enc3 = CBR2d(2 * nker, 4 * nker, kernel_size=4, stride=2,
                          padding=1, norm=norm, relu=0.2, bias=False)

        self.enc4 = CBR2d(4 * nker, out_channels, kernel_size=4, stride=2,
                          padding=1, norm=norm, relu=0.2, bias=False)

    def forward(self, x):
        x = self.enc1(x)
        x = self.enc2(x)
        x = self.enc3(x)
        x = self.enc4(x)

        x = torch.sigmoid(x)

        return x

- A sigmoid appears at the end of the last layer.

- LeakyReLU with slope 0.2.
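
As a quick sanity check (my example, not from the post): with 256x256 inputs and the source/target images concatenated along the channel axis, the output is a 16x16 grid of per-patch real/fake probabilities rather than a single scalar, which is exactly the PatchGAN behavior described in 3.2.2:

netD = Discriminator(in_channels=6, out_channels=1)  # 3 source + 3 target channels
patches = netD(torch.randn(1, 6, 256, 256))
print(patches.shape)  # torch.Size([1, 1, 16, 16]) -> one score per patch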

 

Train
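
The loop below assumes that the models, losses, and optimizers already exist. A plausible setup sketch (dataLoader and the dict keys 'label'/'input' are assumptions carried over from the loop; the Adam settings lr=2e-4, beta1=0.5 and the L1 weight lambda=100 come from the paper):

import numpy as np
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

netG = Pix2Pix(in_channels=3, out_channels=3).to(device)
netD = Discriminator(in_channels=6, out_channels=1).to(device)

ganLoss = nn.BCELoss()  # matches the sigmoid at the end of the discriminator
l1Loss = nn.L1Loss()
wgt = 100               # lambda weighting the L1 term (100 in the paper)

# Adam with lr=2e-4 and beta1=0.5, as in the paper
optimG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))

def set_requires_grad(net, requires_grad):
    # freeze/unfreeze a network so only the one being updated receives gradients
    for param in net.parameters():
        param.requires_grad = requires_grad

st_epoch = 0                       # last finished epoch (0 = training from scratch)
num_epoch = 200                    # assumption; pick per dataset
num_batch_train = len(dataLoader)  # dataLoader (construction omitted) yields {'label': ..., 'input': ...}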

# training loop
for epoch in range(st_epoch + 1, num_epoch + 1):
    netG.train()
    netD.train()

    generatorL1Loss = []
    generatorGanLoss = []
    discriminatorRealLoss = []
    discriminatorFakeLoss = []

    for batch, data in enumerate(dataLoader, 1):
        # forward netG
        label = data['label'].to(device)
        input = data['input'].to(device)
        output = netG(input)

        # backward netD
        set_requires_grad(netD, True)
        optimD.zero_grad()

        real = torch.cat([input, label], dim=1)
        fake = torch.cat([input, output], dim=1)

        predictReal = netD(real)
        predictFake = netD(fake.detach())

        realLossD = ganLoss(predictReal, torch.ones_like(predictReal))
        fakeLossD = ganLoss(predictFake, torch.zeros_like(predictFake))
        discriminatorLoss = 0.5 * (realLossD + fakeLossD)

        discriminatorLoss.backward()
        optimD.step()

        # backward netG
        set_requires_grad(netD, False)
        optimG.zero_grad()

        fake = torch.cat([input, output], dim=1)
        predictFake = netD(fake)

        ganLossG = ganLoss(predictFake, torch.ones_like(predictFake))
        l1LossG = l1Loss(output, label)
        generatorLoss = ganLossG + wgt * l1LossG

        generatorLoss.backward()
        optimG.step()

        # accumulate the losses
        generatorL1Loss += [l1LossG.item()]
        generatorGanLoss += [ganLossG.item()]
        discriminatorRealLoss += [realLossD.item()]
        discriminatorFakeLoss += [fakeLossD.item()]
        
        # print the running mean of each loss every batch
        print("TRAIN: EPOCH %04d / %04d | BATCH %04d / %04d | "
              "GENERATOR L1 LOSS %.4f | GENERATOR GAN LOSS %.4f | "
              "DISCRIMINATOR REAL LOSS: %.4f | DISCRIMINATOR FAKE LOSS: %.4f" %
              (epoch, num_epoch, batch, num_batch_train,
               np.mean(generatorL1Loss), np.mean(generatorGanLoss),
               np.mean(discriminatorRealLoss), np.mean(discriminatorFakeLoss)))

- Training uses mini-batches.

- The generator produces an image; netD takes that image, judges real vs. fake, and is then updated by backpropagation in the gradient-descent direction; afterwards netG is updated by backpropagation in the same way.

- fake.detach() is used only in the netD backward pass, to block the unwanted backpropagation into the generator.

 

 

 

๋ ˆํผ๋Ÿฐ์Šค

: https://medium.com/humanscape-tech/ml-practice-pix2pix-1-d89d1a011dbf

 


 

 

๋ฐ˜์‘ํ˜•