- These notes follow a simplified re-implementation of VQGAN by a developer abroad.
1. A tqdm usage pattern
from tqdm import tqdm
import numpy as np

with tqdm(range(len(train_dataset))) as pbar:
    for i, imgs in zip(pbar, train_dataset):
        # do something
        pbar.set_postfix(
            VQ_Loss=np.round(vq_loss.cpu().detach().numpy().item(), 5),
            GAN_Loss=np.round(gan_loss.cpu().detach().numpy().item(), 3),
        )
        pbar.update(0)  # zip already advances the bar; update(0) just forces a refresh of the postfix
2. Visualizing the network structure with torchinfo
from torchinfo import summary

encoder = Encoder()
summary(encoder, (1, 3, 256, 256))  # prints per-layer output shapes and parameter counts for a 1x3x256x256 input
3. The Encoder part of VQGAN
The model structure is built from the following blocks (architecture diagram omitted):
- a ResidualBlock consists of two GroupNorm + Swish + Conv groups;
- a DownSampleBlock is a single stride-2 Conv;
- a NonLocalBlock is an attention block (see the snippet below).
After the Encoder there is an additional 1x1 conv called quant_conv; its output then passes through the codebook, which has a matching 1x1 post_quant_conv on the other side.
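A minimal sketch of the ResidualBlock and DownSampleBlock described above (the group count of 32, the skip 1x1 conv, and the padding choices are assumptions, not copied from the original repo):
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    # two GroupNorm + Swish + Conv groups with a skip connection
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.GroupNorm(32, in_channels), nn.SiLU(),   # Swish == SiLU; assumes channels divisible by 32
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.GroupNorm(32, out_channels), nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        )
        # 1x1 conv on the skip path when the channel count changes
        self.skip = (nn.Conv2d(in_channels, out_channels, 1)
                     if in_channels != out_channels else nn.Identity())

    def forward(self, x):
        return self.skip(x) + self.block(x)

class DownSampleBlock(nn.Module):
    # a single stride-2 conv that halves the spatial resolution
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)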
# Simplified attention implementation (the core of the NonLocalBlock).
# q has been reshaped to [b, h*w, c]; k and v to [b, c, h*w].
attn = torch.bmm(q, k)             # [b, h*w, h*w] pairwise similarities between spatial positions
attn = attn * (int(c) ** (-0.5))   # scale by 1/sqrt(c)
attn = F.softmax(attn, dim=2)
attn = attn.permute(0, 2, 1)       # transpose so the bmm below aggregates over the key positions
A = torch.bmm(v, attn)             # [b, c, h*w] attention-weighted sum of the values
A = A.reshape(b, c, h, w)          # back to the spatial layout; added to the input as a residual
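For context, a hedged sketch of how that fragment can sit inside a full NonLocalBlock, with 1x1 convs producing q, k, v and a residual connection (module and argument names are assumptions; reuses the imports from the sketch above):
class NonLocalBlock(nn.Module):
    # self-attention over the spatial positions of a [b, c, h, w] feature map
    def __init__(self, channels):
        super().__init__()
        self.gn = nn.GroupNorm(32, channels)
        self.q = nn.Conv2d(channels, channels, 1)
        self.k = nn.Conv2d(channels, channels, 1)
        self.v = nn.Conv2d(channels, channels, 1)
        self.proj_out = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        h_ = self.gn(x)
        q, k, v = self.q(h_), self.k(h_), self.v(h_)
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).permute(0, 2, 1)   # [b, h*w, c]
        k = k.reshape(b, c, h * w)                    # [b, c, h*w]
        v = v.reshape(b, c, h * w)                    # [b, c, h*w]
        attn = F.softmax(torch.bmm(q, k) * c ** (-0.5), dim=2)
        A = torch.bmm(v, attn.permute(0, 2, 1)).reshape(b, c, h, w)
        return x + self.proj_out(A)                   # residual connection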
4. The Codebook part of VQGAN
- Why quantize the Encoder features into discrete codes with a Codebook?
  - it forces the model to learn a more abstract, compressed representation of the data;
  - the enforced information bottleneck helps the generative model generalize;
  - it improves computational efficiency.
import torch
import torch.nn as nn


class Codebook(nn.Module):
    def __init__(self, args):
        super(Codebook, self).__init__()
        self.num_codebook_vectors = args.num_codebook_vectors  # 1024
        self.latent_dim = args.latent_dim  # 256
        self.beta = args.beta  # 0.25
        self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim)
        self.embedding.weight.data.uniform_(
            -1.0 / self.num_codebook_vectors, 1.0 / self.num_codebook_vectors
        )

    def forward(self, z):
        z = z.permute(0, 2, 3, 1).contiguous()  # [b, c, h, w] -> [b, h, w, c]
        z_flattened = z.view(-1, self.latent_dim)  # [1*16*16, 256]
        # squared Euclidean distance from every latent vector to every codebook vector
        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * (torch.matmul(z_flattened, self.embedding.weight.t()))
        )  # [256, 1024]
        min_encoding_indices = torch.argmin(d, dim=1)  # [256] index of the nearest codebook vector
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        # codebook loss (pull codes toward encoder outputs) + beta * commitment loss
        loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
            (z_q - z.detach()) ** 2
        )

        # Straight-through estimator: simply copy the gradient from the decoder to the encoder.
        # This is a very elegant trick, and also an easy place to make mistakes.
        z_q = z + (z_q - z).detach()  # preserve the gradients for the backward flow
        z_q = z_q.permute(0, 3, 1, 2)  # back to [b, c, h, w]

        return z_q, min_encoding_indices, loss
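A minimal smoke test of the Codebook, assuming a 1x256x16x16 encoder output (the args object here is a stand-in for the repo's argparse config):
from types import SimpleNamespace

args = SimpleNamespace(num_codebook_vectors=1024, latent_dim=256, beta=0.25)
codebook = Codebook(args)
z = torch.randn(1, 256, 16, 16)               # fake encoder output
z_q, indices, loss = codebook(z)
print(z_q.shape, indices.shape, loss.item())  # torch.Size([1, 256, 16, 16]) torch.Size([256]) ...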
5. The Decoder part of VQGAN
- Model structure: essentially the Encoder mirrored, with up-sampling blocks in place of the DownSampleBlocks (architecture diagram omitted).
6. The Discriminator part of VQGAN
The discriminator only joins training after a certain number of steps (10000); see the sketch below.
Model structure: a patch-based convolutional discriminator (architecture diagram omitted).
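A hedged sketch of that delayed start: the GAN term is multiplied by a disc_factor that stays at 0 until the step threshold is reached (the function and variable names here are assumptions, not necessarily those of the repo):
def adopt_weight(disc_factor, step, threshold, value=0.0):
    # keep the GAN term switched off until `threshold` steps have passed
    if step < threshold:
        return value
    return disc_factor

# usage in the training loop (names assumed):
disc_factor = adopt_weight(1.0, global_step, threshold=10000)
# vq_loss = perceptual_rec_loss + q_loss + disc_factor * λ * g_loss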
7. The other loss terms of VQGAN
- perceptual_loss: loads a pretrained VGG16 and measures the distance between multi-layer features of real_img and gen_img.
- rec_loss: the pixel-level reconstruction loss between the real and reconstructed images.
- gan_loss:
perceptual_rec_loss = (
    args.perceptual_loss_factor * perceptual_loss
    + args.rec_loss_factor * rec_loss
)
perceptual_rec_loss = perceptual_rec_loss.mean()
g_loss = -torch.mean(disc_fake)  # the generator wants disc_fake to be large, pushing g_loss down
λ = self.vqgan.calculate_lambda(perceptual_rec_loss, g_loss)
vq_loss = perceptual_rec_loss + q_loss + disc_factor * λ * g_loss  # q_loss is the codebook loss from the Codebook forward

d_loss_real = torch.mean(F.relu(1.0 - disc_real))  # hinge loss: wants disc_real > 1
d_loss_fake = torch.mean(F.relu(1.0 + disc_fake))  # hinge loss: wants disc_fake < -1
gan_loss = disc_factor * 0.5 * (d_loss_real + d_loss_fake)

self.opt_vq.zero_grad()
vq_loss.backward(retain_graph=True)

self.opt_disc.zero_grad()
gan_loss.backward()

self.opt_vq.step()
self.opt_disc.step()
8. The calculate_lambda part of the VQGAN loss
- What this function does: it dynamically balances how much the different loss terms contribute during training.
- λ is derived from the gradient magnitudes at the decoder's last layer, balancing the perceptual loss against the GAN loss and helping keep training stable.
def calculate_lambda(self, perceptual_loss, gan_loss):
    last_layer = self.decoder.model[-1]
    last_layer_weight = last_layer.weight
    # gradients of both losses w.r.t. the decoder's last layer weights
    perceptual_loss_grads = torch.autograd.grad(
        perceptual_loss, last_layer_weight, retain_graph=True
    )[0]
    gan_loss_grads = torch.autograd.grad(
        gan_loss, last_layer_weight, retain_graph=True
    )[0]

    # ratio of gradient norms: large when the perceptual loss dominates, small otherwise
    λ = torch.norm(perceptual_loss_grads) / (torch.norm(gan_loss_grads) + 1e-4)
    λ = torch.clamp(λ, 0, 1e4).detach()
    return 0.8 * λ
At this point, the first stage of training VQGAN is done: given a sequence of VQ codes, we can reconstruct an image. Next, a transformer takes over: it first predicts a sequence of VQ codes, and the image is then reconstructed from them.
9. The Transformer part of VQGAN
- With the AdamW optimizer, the parameters are split into groups: biases, nn.LayerNorm, nn.Embedding and the pos_emb are trained without weight decay, while nn.Linear weights do get weight decay (a sketch follows after this list).
- The Transformer network itself is essentially a minGPT.
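A hedged sketch of that weight-decay split, in the style of minGPT's configure_optimizers (the learning rate, decay value and helper name are assumptions, not copied from the repo):
import torch
import torch.nn as nn

def configure_optimizers(model, lr=4.5e-6, weight_decay=0.01, betas=(0.9, 0.95)):
    decay, no_decay = set(), set()
    for mn, m in model.named_modules():
        for pn, _ in m.named_parameters(recurse=False):
            fpn = f"{mn}.{pn}" if mn else pn
            if pn.endswith("bias"):
                no_decay.add(fpn)                       # biases: no weight decay
            elif isinstance(m, nn.Linear):
                decay.add(fpn)                          # Linear weights: weight decay
            elif isinstance(m, (nn.LayerNorm, nn.Embedding)):
                no_decay.add(fpn)                       # LayerNorm / Embedding weights: no decay
    no_decay.add("pos_emb")                             # the learned positional embedding: no decay
    params = dict(model.named_parameters())
    groups = [
        {"params": [params[n] for n in sorted(decay)], "weight_decay": weight_decay},
        {"params": [params[n] for n in sorted(no_decay)], "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(groups, lr=lr, betas=betas)
The GPT forward pass itself: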
def forward(self, idx, embeddings=None):
    # idx: the indices produced by the VQGAN Encoder for the input image, with part of them randomly masked out
    token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
    t = token_embeddings.shape[1]
    position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
    x = self.drop(token_embeddings + position_embeddings)
    x = self.blocks(x)  # a stack of vanilla masked multi-head self-attention blocks
    x = self.ln_f(x)
    logits = self.head(x)  # logits over the 1024 codebook entries for every position
    return logits, None
When training the GPT Transformer:
def forward(self, x):
    _, indices = self.encode_to_z(x)  # indices produced by the VQGAN Encoder + codebook for the input image
    sos_tokens = torch.ones(x.shape[0], 1) * self.sos_token
    sos_tokens = sos_tokens.long().to("cuda")  # index of the start-of-sequence token, 0 by default
    mask = torch.bernoulli(
        self.pkeep * torch.ones(indices.shape, device=indices.device)
    )
    mask = mask.round().to(dtype=torch.int64)
    random_indices = torch.randint_like(indices, self.transformer.config.vocab_size)
    # randomly mask out roughly half of the indices (controlled by pkeep) and replace them with random indices
    new_indices = mask * indices + (1 - mask) * random_indices
    new_indices = torch.cat((sos_tokens, new_indices), dim=1)
    target = indices
    logits, _ = self.transformer(new_indices[:, :-1])
    return logits, target
The model is trained directly with cross entropy on these (logits, target) pairs; once trained, the GPT Transformer can predict the following tokens from the preceding ones.
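A minimal sketch of that cross-entropy training step (the model and optimizer names are assumptions):
import torch.nn.functional as F

logits, targets = model(imgs)  # forward pass above
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
optim_transformer.zero_grad()
loss.backward()
optim_transformer.step()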
10. How do we sample new images once training is done?
@torch.no_grad()
def sample(self, x, c, steps, temperature=1.0, top_k=100):
    self.transformer.eval()
    x = torch.cat((c, x), dim=1)  # c holds the conditioning / sos tokens
    for k in range(steps):
        logits, _ = self.transformer(x)
        logits = logits[:, -1, :] / temperature  # only the prediction for the last position matters
        if top_k is not None:
            logits = self.top_k_logits(logits, top_k)  # keep only the top_k most likely codes
        probs = F.softmax(logits, dim=-1)
        ix = torch.multinomial(probs, num_samples=1)  # sample one index according to the probabilities
        x = torch.cat((x, ix), dim=1)  # append it and continue autoregressively
    x = x[:, c.shape[1]:]  # strip the conditioning tokens
    self.transformer.train()
    return x
- Feed the resulting sequence of indices to the VQGAN Decoder (via the codebook embedding) and it produces the generated image; see the sketch below.
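A hedged sketch of that last step, mapping sampled indices back to an image (method and attribute names such as z_to_image and vqgan.decode are assumptions):
@torch.no_grad()
def z_to_image(self, indices, p1=16, p2=16):
    # look up the codebook vector for each index and restore the [b, c, h, w] latent layout
    z = self.vqgan.codebook.embedding(indices).reshape(indices.shape[0], p1, p2, -1)
    z = z.permute(0, 3, 1, 2)
    # decode (post_quant_conv + Decoder) back to pixel space
    return self.vqgan.decode(z)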