14. InfoGAN 이론 및 구현
InfoGAN은 정보 이론(information theory)을 이용하여 무작위로 생성된 이미지에 대한 명확한 의미를 부여할 수 있도록 설계된 모델이다. 기본적인 GAN 구조에 “정보 최대화”(information maximization) 접근 방식을 추가하여, 생성자가 생성하는 이미지에서 어떤 특정한 특징(feature)을 조절할 수 있도록 한다.
InfoGAN의 핵심 개념
InfoGAN의 핵심 아이디어는 생성자 입력의 일부분을 “조건부 정보”(conditional information)로 사용하여, 생성된 이미지에 특정한 특성을 명확하게 표현할 수 있게 하는 것이다. 예를 들어, 이 조건부 정보를 사용하여 생성된 얼굴 이미지에서 표정, 헤어스타일, 안경 유무 등을 명확하게 조절할 수 있다.
GAN의 기본 손실 함수
GAN의 기본 손실 함수는 다음과 같이 두 부분으로 구성된다:
생성자(Generator) 손실 함수 (L_G): $L_G = -\mathbb{E}_{z \sim p_z(z), c \sim p(c)}[\log D(G(z, c))]$
여기서, $G(z, c)$ 는 잠재 공간의 노이즈 $z$와 조건부 정보 $c$ 를 입력으로 받아 생성된 이미지를 의미하며, $D(G(z, c))$ 는 생성된 이미지를 판별자가 진짜로 판별할 확률이다. 생성자의 목표는 이 값을 최대화하여 판별자를 속이는 것.
판별자(Discriminator) 손실 함수 (L_D): $L_D = -\mathbb{E}{x \sim p{\text{data}}(x)}[\log D(x)] - \mathbb{E}_{z \sim p_z(z), c \sim p(c)}[\log (1 - D(G(z, c)))]$
여기서, $x$ 는 실제 데이터, $D(x)$ 는 실제 이미지를 판별자가 진짜로 판별할 확률을 의미한다. 판별자의 목표는 실제 이미지와 생성된 이미지를 정확하게 구분하는 것.
InfoGAN의 정보 최대화 항
InfoGAN의 핵심은 생성된 이미지와 조건부 정보 간의 상호 정보를 최대화하는 추가 항:
- 상호 정보 최대화 항 $(c; G(z, c))$:
여기서, $H(c)$ 는 조건부 정보 $c$ 의 엔트로피, $H(c|G(z, c))$ 는 생성된 이미지가 주어졌을 때 조건부 정보 $c$ 의 조건부 엔트로피. 상호 정보 $I(c; G(z, c))$는 $c$와 $G(z, c)$ 간에 공유되는 정보의 양을 측정하며, InfoGAN은 이 값을 최대화하여 조건부 정보와 생성된 이미지 사이의 의존성을 강화한다.
InfoGAN의 손실 함수는 기본 GAN의 손실 함수에 “상호 정보”(mutual information) 손실을 추가한 형태로 상호 정보 손실은 생성자가 생성한 이미지와 조건부 정보 사이의 상호 정보를 최대화하는 것을 목표로 한다. InfoGAN의 전체 손실 함수는 다음과 같다:
- 전체 손실 함수: $L = L_{GAN} - \lambda I(c;G(z,c))$
여기서, $L_{GAN}$은 기본 GAN의 손실 함수, $I(c;G(z,c))$는 생성된 이미지 $G(z,c)$와 조건부 정보 $c$ 사이의 상호 정보를 나타내며, $\lambda$는 상호 정보 손실의 중요도를 조절하는 하이퍼파라미터로 이 항을 최대화함으로써, InfoGAN은 $c$ 에 의해 제어되는 생성된 이미지 내의 의미 있는 특징을 학습한다.
InfoGAN의 손실 함수는 생성적 적대 신경망(GAN)의 기본 손실 함수에 추가적인 정보 최대화 항을 포함하여 구성된다. 이 추가 항은 생성된 이미지와 그 이미지를 생성할 때 사용된 조건부 정보 간의 상호 정보(mutual information)를 최대화하는 데 목적을 둔다. 이를 통해, InfoGAN은 생성된 이미지 내에서 의미 있는 특징을 학습하고 조절할 수 있게 된다.
InfoGAN의 특징
명확한 특성 조절: InfoGAN은 생성된 이미지의 특정 특성을 조절할 수 있는 강력한 메커니즘을 제공한다. 이를 통해 사용자는 이미지의 다양한 측면을 명확하게 조절할 수 있다.
높은 해석 가능성: InfoGAN은 생성된 이미지의 특성을 조절하는 데 사용되는 조건부 정보를 통해 높은 해석 가능성을 제공한다. 이는 모델의 이해와 분석을 용이하게 한다.
다양한 응용 분야: InfoGAN은 얼굴 이미지 생성, 손글씨 숫자 변형, 객체의 특정 특성 변경 등 다양한 응용 분야에서 사용될 수 있다.
결론
InfoGAN은 정보 최대화 접근 방식을 통해 생성된 이미지의 특정 특성을 명확하게 조절할 수 있는 능력을 GAN에 추가한다. 이를 통해 생성된 이미지의 다양성과 해석 가능성을 크게 향상시키며, 다양한 응용 분야에서의 활용 가능성을 열어준다. InfoGAN은 딥러닝 기반 이미지 생성 분야에서 중요한 발전을 나타내며, 향후 연구와 응용에 있어 중요한 기초를 제공한다.
구현
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# 필요한 라이브러리들을 import
import argparse
import os
import numpy as np
import math
import itertools
import matplotlib.pyplot as plt
import cv2
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
os.makedirs("images/static/", exist_ok=True)
os.makedirs("images/varying_label/", exist_ok=True)
os.makedirs("images/varying_c1/", exist_ok=True)
os.makedirs("images/varying_c2/", exist_ok=True)
# 가중치를 정규 분포로 초기화하는 함수 정의
def weights_init_normal(m):
classname = m.__class__.__name__
# Conv 레이어의 경우, 가중치를 평균 0, 표준편차 0.02의 정규 분포로 초기화
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
# BatchNorm 레이어의 경우, 가중치는 평균 1, 표준편차 0.02의 정규 분포로 초기화하고, 편향은 0으로 초기화
elif classname.find("BatchNorm") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
# 범주형 데이터를 원-핫 인코딩으로 변환하는 함수 정의
def to_categorical(y, num_columns):
"""Returns one-hot encoded Variable"""
y_cat = np.zeros((y.shape[0], num_columns))
y_cat[range(y.shape[0]), y] = 1.0
return Variable(FloatTensor(y_cat))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Option():
n_epochs = 200 # 학습을 수행할 에폭의 수
batch_size = 64 # 한 번에 학습할 배치의 크기
lr = 0.0002 # Adam 최적화기의 학습률
b1 = 0.5 # Adam 최적화기의 첫 번째 모멘텀 감쇠율
b2 = 0.999 # Adam 최적화기의 두 번째 모멘텀 감쇠율
n_cpu = 8 # 배치 생성 시 사용할 CPU 스레드의 수
latent_dim = 62 # 잠재 공간의 차원 수
code_dim = 2 # 잠재 코드의 차원 수
n_classes = 10 # 데이터셋의 클래스 수
img_size = 32 # 각 이미지 차원의 크기
channels = 1 # 이미지 채널의 수
sample_interval = 500 # 이미지 샘플링 간격
opt = Option() # Option 클래스의 인스턴스 생성
# CUDA가 사용 가능한 경우 사용하도록 설정
cuda = True if torch.cuda.is_available() else False
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
# 잠재 벡터, 레이블, 코드의 차원을 합한 것이 입력 차원
input_dim = opt.latent_dim + opt.n_classes + opt.code_dim
# 업샘플링 전의 초기 크기
self.init_size = opt.img_size // 4 # 8
# 첫 번째 선형 레이어
self.l1 = nn.Sequential(nn.Linear(input_dim, 128 * self.init_size ** 2)) # 128 * 8 * 8
# 컨볼루션 블록: 업샘플링과 컨볼루션 레이어를 포함
self.conv_blocks = nn.Sequential(
nn.BatchNorm2d(128), # 배치 정규화
nn.Upsample(scale_factor=2), # 이미지 크기 2배로 업샘플링
nn.Conv2d(128, 128, 3, stride=1, padding=1), # 128 채널 컨볼루션 레이어
nn.BatchNorm2d(128, 0.8), # 배치 정규화
nn.LeakyReLU(0.2, inplace=True), # LeakyReLU 활성화 함수
nn.Upsample(scale_factor=2), # 다시 이미지 크기 2배로 업샘플링
nn.Conv2d(128, 64, 3, stride=1, padding=1), # 64 채널 컨볼루션 레이어
nn.BatchNorm2d(64, 0.8), # 배치 정규화
nn.LeakyReLU(0.2, inplace=True), # LeakyReLU 활성화 함수
nn.Conv2d(64, opt.channels, 3, stride=1, padding=1), # 최종 채널 수로 컨볼루션
nn.Tanh(), # 출력을 [-1, 1] 범위로 조정하는 Tanh 활성화 함수
)
def forward(self, noise, labels, code):
gen_input = torch.cat((noise, labels, code), -1) # 잠재 벡터, 레이블, 코드를 결합
out = self.l1(gen_input) # 첫 번째 선형 레이어에 입력
out = out.view(out.shape[0], 128, self.init_size, self.init_size) # 배치, 128, 8, 8
img = self.conv_blocks(out) # 컨볼루션 블록을 통과하여 최종 이미지 생성
return img # 생성된 이미지 반환
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
def discriminator_block(in_filters, out_filters, bn=True):
# 판별자의 각 블록을 정의: 컨볼루션, LeakyReLU, 드롭아웃, 선택적 배치 정규화 포함
block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
if bn:
block.append(nn.BatchNorm2d(out_filters, 0.8))
return block
# 컨볼루션 블록: 다양한 컨볼루션 레이어와 배치 정규화, 활성화 함수, 드롭아웃을 포함
self.conv_blocks = nn.Sequential(
*discriminator_block(opt.channels, 16, bn=False),
*discriminator_block(16, 32),
*discriminator_block(32, 64),
*discriminator_block(64, 128),
)
# 다운샘플링된 이미지의 크기
ds_size = opt.img_size // 2 ** 4 # 2
# 최종 출력 레이어: 진짜/가짜 판별, 클래스 레이블 예측, 연속적인 코드 추출
# 진짜/가짜 판별을 위한 레이어
self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1)) # input: 128 * 4, output: 1
# 클래스 레이블 예측 레이어 (input: 128 * 4, output: n_classes)
self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax())
# 연속적인 코드 추출 레이어
self.latent_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.code_dim))
def forward(self, img):
out = self.conv_blocks(img) # 컨볼루션 블록 통과
out = out.view(out.shape[0], -1) # 적절한 형태로 변형
validity = self.adv_layer(out) # 진짜/가짜 판별
label = self.aux_layer(out) # 클래스 레이블 예측
latent_code = self.latent_layer(out) # 연속적인 코드 추출
return validity, label, latent_code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# 손실 함수 정의
adversarial_loss = torch.nn.MSELoss() # 생성자와 판별자 간의 경쟁을 평가하는데 사용됩니다. 진짜와 가짜를 구분하는 데 사용.
categorical_loss = torch.nn.CrossEntropyLoss() # 레이블에 대한 손실을 계산합니다. 분류 문제에 사용됩니다.
continuous_loss = torch.nn.MSELoss() # 연속적인 값에 대한 손실을 계산합니다. 생성된 이미지의 연속적 특징을 조절하는 데 사용.
# 손실 가중치 설정
lambda_cat = 1 # 범주형 손실의 가중치
lambda_con = 0.1 # 연속적 손실의 가중치
# 생성자와 판별자 초기화
generator = Generator()
discriminator = Discriminator()
# CUDA 사용 가능 여부에 따라 모델과 손실 함수를 GPU로 옮깁니다.
if cuda:
generator.cuda()
discriminator.cuda()
adversarial_loss.cuda()
categorical_loss.cuda()
continuous_loss.cuda()
# 모델의 가중치 초기화
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)
# 데이터 로더 설정
os.makedirs("dataset/mnist", exist_ok=True) # 데이터셋 저장 폴더 생성 (없는 경우)
dataloader = torch.utils.data.DataLoader(
datasets.MNIST(
"dataset/mnist", # 데이터셋 경로
train=True, # 훈련 데이터셋 사용
download=False, # 데이터셋 다운로드 여부
transform=transforms.Compose([ # 이미지 전처리: 크기 조정, 텐서 변환, 정규화
transforms.Resize(opt.img_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
]),
),
batch_size=opt.batch_size, # 배치 크기 설정
shuffle=True, # 데이터셋을 무작위로 섞음
)
# 최적화 함수 설정
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
# 생성자와 판별자의 파라미터를 함께 최적화하는 최적화 함수
"""
itertools.chain 함수는 여러 이터러블(iterable) 객체들을 하나의 연속된 이터러블로 결합.
여기서는 생성자와 판별자의 파라미터를 하나의 이터러블로 결합하여 Adam 최적화 함수에 전달.
- 파라미터 최적화: itertools.chain(generator.parameters(), discriminator.parameters())를 사용해,
생성자와 판별자의 모든 파라미터를 하나의 그룹으로 묶고, 이 파라미터들을 최적화하기 위해 Adam 알고리즘을 적용.
- 정보 최적화: InfoGAN에서는 추가적으로 생성된 이미지와 잠재 코드 사이의 상호 정보를 최대화하는 최적화가 필요하다.
상호 정보 최대화는 주로 생성자의 손실 함수에 추가적인 항으로 포함되어, 생성자가 잠재 코드에 의미 있는 정보를 인코딩하도록 유도.
"""
optimizer_info = torch.optim.Adam(
itertools.chain(generator.parameters(), discriminator.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
# 텐서 타입 설정
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
정적 생성자 입력 설정
static_z
: 생성된 이미지의 다양성을 위해 사용되는 정적인 잠재 벡터.static_label
: 생성하고자 하는 이미지의 클래스를 결정하는 정적인 레이블.static_code
: 이미지의 특정 연속적인 특성을 제어하는 정적인 코드.
1
2
3
4
5
6
7
8
9
10
11
12
# 정적 생성자 입력을 위한 설정
# 모든 클래스에 대한 잠재 벡터(z) 초기화
static_z = Variable(FloatTensor(np.zeros((opt.n_classes ** 2, opt.latent_dim)))) # (100, 62)
# 정적 레이블 생성: 모든 클래스 조합에 대해 원-핫 인코딩을 적용
static_label = to_categorical(
np.array([num for _ in range(opt.n_classes) for num in range(opt.n_classes)]), num_columns=opt.n_classes
# (100, 10)
)
# 연속적인 코드(code) 초기화
static_code = Variable(FloatTensor(np.zeros((opt.n_classes ** 2, opt.code_dim)))) # (100, 2)
이미지 샘플링 및 저장 함수
- 이 함수는 정적인 잠재 벡터, 레이블, 코드를 사용하여 생성된 이미지를 저장. 또한, 연속적인 코드 값을 변화시키면서 생성된 이미지를 저장하여, 코드 값의 변화가 이미지에 어떤 영향을 미치는지 시각적으로 확인할 수 있다.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
def sample_image(n_row, batches_done): """정적 샘플과 변화하는 코드를 사용하여 이미지 그리드 생성 및 저장""" # 정적 샘플 생성 및 저장 z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim)))) static_sample = generator(z, static_label, static_code) save_image(static_sample.data, "images/static/%d.png" % batches_done, nrow=n_row, normalize=True) # 연속적인 코드(c1, c2) 변화에 따른 이미지 생성 및 저장 zeros = np.zeros((n_row ** 2, 1)) c_varied = np.repeat(np.linspace(-1, 1, n_row)[:, np.newaxis], n_row, 0) c1 = Variable(FloatTensor(np.concatenate((c_varied, zeros), -1))) c2 = Variable(FloatTensor(np.concatenate((zeros, c_varied), -1))) sample1 = generator(static_z, static_label, c1) sample2 = generator(static_z, static_label, c2) save_image(sample1.data, "images/varying_c1/%d.png" % batches_done, nrow=n_row, normalize=True) save_image(sample2.data, "images/varying_c2/%d.png" % batches_done, nrow=n_row, normalize=True)
결과 시각화 함수
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def display_results(batches_done):
"""생성된 이미지의 그리드를 시각화"""
plt.figure(figsize = (15,5))
# 정적 샘플 이미지 시각화
plt.subplot(1,3,1)
img1 = cv2.imread("images/static/%d.png" % batches_done)
plt.imshow(img1, interpolation='nearest')
plt.xlabel('categorical code')
plt.ylabel('random z')
# 연속적인 코드 1에 따른 변화 이미지 시각화
plt.subplot(1,3,2)
img1 = cv2.imread("images/varying_c1/%d.png" % batches_done)
plt.imshow(img1, interpolation='nearest')
plt.xlabel('categorical code')
plt.ylabel('continuous code 1 [-1,1]')
# 연속적인 코드 2에 따른 변화 이미지 시각화
plt.subplot(1,3,3)
img1 = cv2.imread("images/varying_c2/%d.png" % batches_done)
plt.imshow(img1, interpolation='nearest')
plt.xlabel('categorical code')
plt.ylabel('continuous code 2 [-1,1]')
plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
for epoch in range(opt.n_epochs): # 에포크 수만큼 반복
for i, (imgs, labels) in enumerate(dataloader): # 데이터 로더에서 배치를 순차적으로 가져옴
batch_size = imgs.shape[0] # 배치 크기 설정
# 진짜와 가짜를 판별할 때 사용할 레이블 생성
valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)
# 실제 이미지와 레이블을 적절한 텐서로 변환
real_imgs = Variable(imgs.type(FloatTensor))
labels = to_categorical(labels.numpy(), num_columns=opt.n_classes)
# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad() # 생성자의 그래디언트를 0으로 초기화
# 잠재 공간에서 노이즈와 레이블을 샘플링
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
label_input = to_categorical(np.random.randint(0, opt.n_classes, batch_size),
num_columns=opt.n_classes)
code_input = Variable(FloatTensor(np.random.uniform(-1, 1, (batch_size, opt.code_dim))))
# 배치의 이미지 생성
gen_imgs = generator(z, label_input, code_input)
# 생성자가 판별자를 속이는 데 사용할 손실 계산
validity, _, _ = discriminator(gen_imgs)
g_loss = adversarial_loss(validity, valid)
g_loss.backward() # 손실에 대한 역전파
optimizer_G.step() # 생성자의 가중치 업데이트
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad() # 판별자의 그래디언트를 0으로 초기화
# 실제 이미지와 가짜 이미지에 대한 손실 계산
real_pred, _, _ = discriminator(real_imgs)
d_real_loss = adversarial_loss(real_pred, valid)
fake_pred, _, _ = discriminator(gen_imgs.detach())
d_fake_loss = adversarial_loss(fake_pred, fake)
# 총 판별자 손실
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.backward() # 손실에 대한 역전파
optimizer_D.step() # 판별자의 가중치 업데이트
# ---------------------
# 정보 손실 최소화
# ---------------------
optimizer_info.zero_grad() # 정보 손실에 대한 최적화 함수 초기화
# 잠재 공간에서 새로운 노이즈, 레이블, 코드 샘플링
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
label_input = to_categorical(np.random.randint(0, opt.n_classes, batch_size),
num_columns=opt.n_classes)
code_input = Variable(FloatTensor(np.random.uniform(-1, 1, (batch_size, opt.code_dim))))
gen_imgs = generator(z, label_input, code_input)
_, pred_label, pred_code = discriminator(gen_imgs)
# 정보 손실 계산 및 최적화
info_loss = lambda_cat * categorical_loss(pred_label, gt_labels) +
lambda_con * continuous_loss(pred_code, code_input)
info_loss.backward() # 손실에 대한 역전파
optimizer_info.step() # 정보 최적화 함수에 따른 가중치 업데이트
# ---------------------
# 진행 상황 로깅 및 이미지 샘플링
# ---------------------
batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0: # 지정된 간격마다
print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [info loss: %f]"
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), info_loss.item())
)
sample_image(n_row=10, batches_done=batches_done) # 이미지 샘플링 및 저장
display_results(batches_done) # 생성된 이미지 시각화