VQGAN Source Code Walkthrough


  • This walkthrough is based on a simplified third-party re-implementation of VQGAN.

1. A tqdm usage pattern

with tqdm(range(len(train_dataset))) as pbar:
    # zip() advances the progress bar once per batch
    for i, imgs in zip(pbar, train_dataset):
        # do something
        pbar.set_postfix(
            VQ_Loss=np.round(vq_loss.cpu().detach().numpy().item(), 5),
            GAN_Loss=np.round(gan_loss.cpu().detach().numpy().item(), 3),
        )
        pbar.update(0)  # update(0) only refreshes the displayed postfix

2. Visualizing the network structure with torchinfo

from torchinfo import summary

encoder = Encoder()
summary(encoder, (1, 3, 256, 256))

3. The VQGAN Encoder

  • Model structure:
    (figure: VQGAN Encoder architecture)

  • Specifically:

    • ResidualBlock consists of two GroupNorm + Swish + Conv stacks (see the sketch after this list);
    • DownSampleBlock is a single Conv with stride=2;
    • NonLocalBlock is an attention block;
  • After encoding, another 1x1 conv called quant_conv is applied before the codebook lookup; symmetrically, a 1x1 post_quant_conv sits between the codebook and the Decoder.
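
A minimal sketch of the ResidualBlock and DownSampleBlock described above (the exact layer arguments are assumptions, not copied from the repo):

import torch.nn as nn

class ResidualBlock(nn.Module):
    # two GroupNorm + Swish + Conv stacks, plus a skip connection
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.GroupNorm(32, in_channels),
            nn.SiLU(),  # Swish
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        )
        # 1x1 conv on the skip path when the channel count changes
        self.channel_up = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x):
        return self.channel_up(x) + self.block(x)

class DownSampleBlock(nn.Module):
    # halves the spatial resolution with a single stride-2 conv
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)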

# Simplified attention inside the NonLocalBlock
# q: [b, h*w, c], k: [b, c, h*w], v: [b, c, h*w]
attn = torch.bmm(q, k)             # [b, h*w, h*w]
attn = attn * (int(c) ** (-0.5))   # scale by 1/sqrt(c)
attn = F.softmax(attn, dim=2)
attn = attn.permute(0, 2, 1)
A = torch.bmm(v, attn)             # [b, c, h*w]
A = A.reshape(b, c, h, w)
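
For context, the q, k, v above come from 1x1 convs over a GroupNorm-ed feature map of shape [b, c, h, w]; the names group_norm, conv_q/k/v and proj_out below are hypothetical:

h_ = group_norm(x)
q = conv_q(h_).reshape(b, c, h * w).permute(0, 2, 1)  # [b, h*w, c]
k = conv_k(h_).reshape(b, c, h * w)                   # [b, c, h*w]
v = conv_v(h_).reshape(b, c, h * w)                   # [b, c, h*w]
# ... attention as computed above ...
out = x + proj_out(A)                                 # residual connection around the attention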

4. The VQGAN Codebook

  • Why use a Codebook to discretize the features produced by the Encoder?
    • It forces the model to learn a more abstract, compressed representation of the data.
    • The enforced information bottleneck helps the generative model generalize better.
    • It improves computational efficiency.

import torch
import torch.nn as nn

class Codebook(nn.Module):
    def __init__(self, args):
        super(Codebook, self).__init__()
        self.num_codebook_vectors = args.num_codebook_vectors  # 1024
        self.latent_dim = args.latent_dim  # 256
        self.beta = args.beta  # 0.25

        self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim)
        self.embedding.weight.data.uniform_(
            -1.0 / self.num_codebook_vectors, 1.0 / self.num_codebook_vectors
        )

    def forward(self, z):
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.latent_dim)  # [1*16*16, 256]

        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * (torch.matmul(z_flattened, self.embedding.weight.t()))
        )  # [256, 1024]

        min_encoding_indices = torch.argmin(d, dim=1)  # [256]
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
            (z_q - z.detach()) ** 2
        )

        # straight-through estimator: simply copy the gradient from the decoder to the encoder.
        # This trick is very elegant, and also very easy to get wrong.
        z_q = z + (z_q - z).detach()  # preserve the gradients for the backward flow.

        z_q = z_q.permute(0, 3, 1, 2)

        return z_q, min_encoding_indices, loss
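
A quick shape check for the Codebook (a usage sketch; the args values follow the comments above):

from types import SimpleNamespace

args = SimpleNamespace(num_codebook_vectors=1024, latent_dim=256, beta=0.25)
codebook = Codebook(args)

z = torch.randn(1, 256, 16, 16)  # Encoder output after quant_conv
z_q, indices, loss = codebook(z)
print(z_q.shape)      # torch.Size([1, 256, 16, 16])
print(indices.shape)  # torch.Size([256]), one codebook index per spatial position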

5. The VQGAN Decoder

  • Model structure:
    (figure: VQGAN Decoder architecture)
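
The Decoder roughly mirrors the Encoder, with upsampling in place of downsampling. A minimal sketch of an UpSampleBlock under that assumption (not copied from the repo):

import torch.nn as nn
import torch.nn.functional as F

class UpSampleBlock(nn.Module):
    # doubles the spatial resolution: nearest-neighbor upsample followed by a conv
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, 1, 1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0)
        return self.conv(x)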

6. The VQGAN Discriminator

  • The discriminator term is only switched on after a certain number of training steps (10000); see the sketch after this list.

  • Model structure:
    (figure: VQGAN Discriminator architecture)
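
A sketch of the warm-up switch (function name and signature are assumptions; the resulting disc_factor multiplies the GAN terms in the next section):

def adopt_weight(disc_factor, global_step, threshold, value=0.0):
    # keep the GAN loss disabled (factor = value, i.e. 0.0) until `threshold` steps
    if global_step < threshold:
        disc_factor = value
    return disc_factor

disc_factor = adopt_weight(1.0, global_step=500, threshold=10000)  # -> 0.0 before step 10000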

7. The other VQGAN losses

  • perceptual_loss: a pretrained VGG16 is loaded and the loss is computed as the distance between multi-layer features of real_img and gen_img.
  • rec_loss: the reconstruction loss.
  • gan_loss: computed as shown below (a short sketch of the first two terms comes first).
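
A sketch of how the perceptual_loss and rec_loss referenced above might be computed (the variable names and the LPIPS-style module are assumptions, not copied from the repo):

perceptual_loss = self.perceptual_loss(imgs, decoded_images)  # LPIPS-style VGG16 feature distance
rec_loss = torch.abs(imgs - decoded_images)                   # per-pixel L1 reconstruction error

They are then weighted, combined with the GAN terms, and optimized: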
perceptual_rec_loss = (
    args.perceptual_loss_factor * perceptual_loss
    + args.rec_loss_factor * rec_loss
)
perceptual_rec_loss = perceptual_rec_loss.mean()
g_loss = -torch.mean(disc_fake)  # generator term: wants disc_fake to be large, pushing g_loss negative

λ = self.vqgan.calculate_lambda(perceptual_rec_loss, g_loss)
vq_loss = perceptual_rec_loss + q_loss + disc_factor * λ * g_loss  # q_loss is the codebook loss from section 4

d_loss_real = torch.mean(F.relu(1.0 - disc_real))  # hinge loss: zero once disc_real > 1
d_loss_fake = torch.mean(F.relu(1.0 + disc_fake))  # hinge loss: zero once disc_fake < -1
gan_loss = disc_factor * 0.5 * (d_loss_real + d_loss_fake)

self.opt_vq.zero_grad()
vq_loss.backward(retain_graph=True)  # keep the graph alive so gan_loss.backward() can reuse it

self.opt_disc.zero_grad()
gan_loss.backward()

self.opt_vq.step()
self.opt_disc.step()

8. The calculate_lambda part of the VQGAN loss

  • What this function does: it dynamically adjusts how much the different loss terms contribute during training.
    • It balances the perceptual/reconstruction loss against the GAN loss based on their gradients at the last decoder layer, which helps keep the losses in balance over the course of training.
def calculate_lambda(self, perceptual_loss, gan_loss):
    last_layer = self.decoder.model[-1]
    last_layer_weight = last_layer.weight
    perceptual_loss_grads = torch.autograd.grad(
        perceptual_loss, last_layer_weight, retain_graph=True
    )[0]
    gan_loss_grads = torch.autograd.grad(
        gan_loss, last_layer_weight, retain_graph=True
    )[0]

    λ = torch.norm(perceptual_loss_grads) / (torch.norm(gan_loss_grads) + 1e-4)
    λ = torch.clamp(λ, 0, 1e4).detach()
    return 0.8 * λ
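
In formula form, this returns 0.8 · λ with λ = ‖∇_{G_L}[L_rec]‖ / (‖∇_{G_L}[L_GAN]‖ + 1e-4), clipped to [0, 1e4], where G_L is the weight of the last decoder layer.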

At this point the first stage of VQGAN training is complete: given a set of VQ codes we can reconstruct an image. Next, a transformer takes over: it first predicts a sequence of VQ codes, and the image is then reconstructed from them.


9. The VQGAN Transformer

  • The AdamW optimizer splits the parameters into two groups: biases, nn.LayerNorm, nn.Embedding and pos_emb are trained without weight decay, while nn.Linear weights use weight decay (see the sketch after the forward pass below).
  • The transformer network itself is essentially a minGPT; its forward pass:
def forward(self, idx, embeddings=None):
    # idx: the indices obtained from the VQGAN Encoder for the input image, with a random subset of them replaced (masked).
    token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector

    t = token_embeddings.shape[1]
    position_embeddings = self.pos_emb[
        :, :t, :
    ]  # each position maps to a (learnable) vector
    x = self.drop(token_embeddings + position_embeddings)
    x = self.blocks(x)  # a stack of masked multi-head self-attention Transformer blocks
    x = self.ln_f(x)
    logits = self.head(x)  # [B, T, vocab_size=1024]

    return logits, None
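
The weight-decay split mentioned above, sketched in the style of minGPT's configure_optimizers (the hyperparameters and the model variable are assumptions):

import torch
import torch.nn as nn

decay, no_decay = set(), set()
for mn, m in model.named_modules():
    for pn, p in m.named_parameters():
        fpn = f"{mn}.{pn}" if mn else pn  # full parameter name
        if pn.endswith("bias"):
            no_decay.add(fpn)  # no weight decay on biases
        elif pn.endswith("weight") and isinstance(m, nn.Linear):
            decay.add(fpn)  # weight decay on Linear weights
        elif pn.endswith("weight") and isinstance(m, (nn.LayerNorm, nn.Embedding)):
            no_decay.add(fpn)  # no weight decay on LayerNorm / Embedding weights
no_decay.add("pos_emb")  # the learned positional embedding is excluded as well

param_dict = {pn: p for pn, p in model.named_parameters()}
optim_groups = [
    {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": 0.01},
    {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=4.5e-06, betas=(0.9, 0.95))
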
  • When training this GPT transformer, the wrapper's forward pass looks like this:

    def forward(self, x):
        _, indices = self.encode_to_z(x)  # indices obtained from the VQGAN Encoder for the input image
    
        sos_tokens = torch.ones(x.shape[0], 1) * self.sos_token
        sos_tokens = sos_tokens.long().to("cuda")  # index of the start-of-sequence token, 0 by default
    
        mask = torch.bernoulli(
            self.pkeep * torch.ones(indices.shape, device=indices.device)
        )
        mask = mask.round().to(dtype=torch.int64)
        random_indices = torch.randint_like(indices, self.transformer.config.vocab_size)
        # randomly replace a fraction (1 - pkeep) of the indices with random indices
        new_indices = mask * indices + (1 - mask) * random_indices
    
        new_indices = torch.cat((sos_tokens, new_indices), dim=1)
        target = indices
        logits, _ = self.transformer(new_indices[:, :-1])
    
        return logits, target  # trained directly with cross entropy (see the sketch below)
  • Once this training is done, the GPT transformer can predict the following tokens from the preceding ones.
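
The training step itself then reduces to a standard next-token cross-entropy loss (a sketch; model, imgs and the optimizer are assumed to exist in the training loop):

import torch.nn.functional as F

logits, targets = model(imgs)             # model: the VQGANTransformer wrapper above
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),  # [B*T, vocab_size]
    targets.reshape(-1),                  # [B*T]
)
optimizer.zero_grad()
loss.backward()
optimizer.step()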

10. How do we sample new images once training is done?

@torch.no_grad()
def sample(self, x, c, steps, temperature=1.0, top_k=100):
    self.transformer.eval()
    x = torch.cat((c, x), dim=1)
    for k in range(steps):
        logits, _ = self.transformer(x)
        logits = logits[:, -1, :] / temperature

        if top_k is not None:
            logits = self.top_k_logits(logits, top_k)

        probs = F.softmax(logits, dim=-1)
        ix = torch.multinomial(probs, num_samples=1)  # sample one index according to the probabilities
        x = torch.cat((x, ix), dim=1)

    x = x[:, c.shape[1]:]
    self.transformer.train()
    return x
  • Feeding the resulting sequence of indices to the VQGAN Decoder then produces an image, roughly as sketched below.
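
A sketch of that final decode step (the method name z_to_image, the 16x16 latent grid and the 256-dim latent follow the shapes used earlier and are assumptions):

@torch.no_grad()
def z_to_image(self, indices, p1=16, p2=16):
    # look up the codebook vector for each index and restore the [B, C, H, W] layout
    ix_to_vectors = self.vqgan.codebook.embedding(indices).reshape(
        indices.shape[0], p1, p2, 256
    )
    ix_to_vectors = ix_to_vectors.permute(0, 3, 1, 2)
    image = self.vqgan.decode(ix_to_vectors)
    return image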
