AI・機械学習 データ分析実践 画像認識

【PyTorchで実装】ドラクエ画像でVQ-VAEを学習する

2022年1月27日

以前、以下の記事ではVQ-VAE(Vector-Quantized Variational Auto-Encoder)をTensorflowで実装しました。

今回は、TensorflowではなくPyTorchを使っている人も多いと思いますので、PyTorchでVQ-VAEを実装したいと思います

細かいモデルの説明は省略しますので、必要に応じて上記の記事をご参照ください。

なお、VQ-VAEで画像を生成するは、インプット画像を潜在変数に符号化して元の画像を復元するオートエンコーダ部分と、潜在変数の分布を学習して潜在変数を生成する部分にわかれますが、今回はオートエンコーダ部分だけを実装していきます

潜在変数を生成する部分はまた何かのモデル(論文ではPixelCNN)で実装したいと思います。

VQ-VAEの概要

まず、VQ-VAEについてざっくり解説します。

VQ-VAEは潜在変数を使った画像の生成モデルで、VAE(Variational Auto-Encoder)と似たような考え方です。

以下の図だと、犬の画像をエンコーダで潜在変数にマッピングし、デコーダで犬の画像を復元するように学習します。

画像に alt 属性が指定されていません。ファイル名: image.png

ではVQ-VAEの特徴はというと、潜在変数が離散的なベクトルで表されるという点です。

先に実装イメージを記載しておくと、以下のような処理になります(詳細は少しずつ説明します)。

画像に alt 属性が指定されていません。ファイル名: image-44.png
離散的なベクトルで表現とは

具体的には、エンコーダで\(z_e\)という潜在変数の一歩手前みたいなものを求めますが、そこから離散的なK種類の埋め込み表現(Embeddings)にマッピングします

以下がマッピングされた潜在変数ですが、各セルが\({\bf{e}}_1, {\bf{e}}_2, ...\)となっています。

画像に alt 属性が指定されていません。ファイル名: image-9.png

この\({\bf{e}}_1, {\bf{e}}_2, ...\)は、こちらのEmbedding Spaceと呼ばれる埋め込み表現の空間から一つ選んだものになっています。

画像に alt 属性が指定されていません。ファイル名: image-12.png

つまり、自由に値を取れるベクトルを埋め込み表現とするのではなく、各セルの値を\(K\)種類のベクトル(辞書みたいなもの)から一つ選んで設定する、というのがVQ-VAEの特徴です。

埋め込みベクトルの選び方

エンコーダで計算した\(z_e\)ともっとも距離が近い埋め込み表現\(e_j\)を選びます

数式で書くと以下のようになります。

$$z_q(x)=e_k, \hspace{10pt}\text{where } k=\arg\min_j\|z_e(x)-e_j\|_2$$

ここで、argminを取ると、その勾配が計算できないという問題が発生しますが、ここでは勾配を計算せず、その前のレイヤーから流れてきた勾配をそのまま流すというstraight estimatorという手法を使っています。

損失関数

損失関数は以下になります。

$$L=\log p\left(x|z_q(x)\right) + \left\|\text{sg}[z_e(x)]-e \right \|^2_2 + \beta \left\|z_e(x)-\text{sg}[e]\right\|^2_2$$

1項目は再構築誤差項になります(実装では二乗誤差を使っています)。

2項目は埋め込み表現を更新するための誤差項になります。

sgというのはStop Gradientの略で、勾配を計算しないということです。

3項目は、エンコーダーのアウトプット\(z_e\)が埋め込み表現\(e\)に対して、先にどんどん更新されないようにする項になります。

各項のより詳しい説明は『【論文解説+Tensorflowで実装】VQ-VAEを理解する』をご参照いただければと思います。

画像の生成について

以上はオートエンコーダ部分の仕組みでした。

これにより、画像が与えられたら、低次元の潜在変数にマッピングして、再度もとの画像を復元することができます

しかし、実際に新しい画像を生成するには、潜在変数を何等かのやり方で生成しなければなりません

VAEでは潜在変数に標準正規分布を仮定しているので、標準正規分布からランダムに潜在変数をサンプリングし、学習済みVAEのデコーダで、画像を生成することができました。

VQ-VAEでは潜在変数の分布がわからないので、それを知る必要があります。

そこで、論文ではPixelCNNという画像の生成モデルで潜在変数の分布を学習します

潜在変数の空間なので、もとの画像よりも低次元であり学習が簡単です。

そしてその学習したPixelCNNでランダムに潜在変数を生成し、学習済みのVQ-VAEのデコーダで画像を生成します

今回の記事では、オートエンコーダ部分の実装のみですので、PixelCNNを含めた画像生成はまたの機会にしたいと思います。

PixelCNN自体の解説実装はこちらでしています。

データについて

データは以下のようなスクレイピングしてきたドラクエ画像を使用します。

ただ、ちゃんとした画像だけだとサンプル数が少なかったため、Google検索でもう少し幅広に集めました。

以下のような画像です。

結構関係のない画像や学習にはいまいちな画像も含まれていたため、そこは目で見て落としていきました(除き切れていない、あまり良くない画像もありますね…)。

現状は学習データとして5000枚弱ぐらいあります。

もう少しスクレイピングしているので、将来的は1~2万枚ぐらいまでは増やしたいと思っています。

VQ-VAEの実装

では、まず必要なライブラリをインストールしていきましょう。

# PyTorch関連
import torch
import torch.nn as nn
import torch.nn.functional as F

# 描画など
import matplotlib.pyplot as plt
import numpy as np

モデル

次のような形で実装していきます。

インプット
  ↓
デコーダ
  ↓
pre_conv(潜在変数の一歩手前を計算する畳み込み層)
  ↓
Vector Quantizer(埋め込みベクトルへのマッピング)
  ↓
デコーダ
  ↓
アウトプット

という順です。

VQ-VAEの実装概要

では、順に作成していきましょう。

残差結合

まず、エンコーダ、デコーダでは残差結合を使いますので、先に残差結合層を作成しておきます。

class Residual(nn.Module):
  def __init__(self, in_channels, hidden_dim, num_residual_hiddens):
    super(Residual, self).__init__()
    self.conv_1 = nn.Conv2d(in_channels=in_channels,
                            out_channels=num_residual_hiddens,
                            kernel_size=3,
                            stride=1,
                            padding=1)
    self.conv_2 = nn.Conv2d(in_channels=num_residual_hiddens,
                            out_channels=hidden_dim,
                            kernel_size=1, 
                            stride=1)

  def forward(self, x):
    h = torch.relu(x)
    h = torch.relu(self.conv_1(h))
    h = self.conv_2(h)
    return x + h # 残差結合

そして、この残差結合を複数積み重ねますので、ResidualStackというクラスを作成します。

ここでは、nn.ModuleListを使って、num_residual_layers分だけ残差結合層をスタックします。

class ResidualStack(nn.Module):
  def __init__(self, in_channels, hidden_dim, num_residual_layers, num_residual_hiddens):
    super(ResidualStack, self).__init__()
    self._num_residual_layers = num_residual_layers
    # 複数の残差結合をスタックする
    self._layers = nn.ModuleList(
        [Residual(in_channels, hidden_dim, num_residual_hiddens)
                          for _ in range(self._num_residual_layers)]
                          )

  def forward(self, x):
    for i in range(self._num_residual_layers):
        x = self._layers[i](x)
    return F.relu(x)

エンコーダ

続いてエンコーダです。

エンコーダは以下のような仕組みになります。

エンコーダの仕組み

複数の畳み込み層と先ほど作成した残差結合のresidual stack層で構成されています。

上図のような順で計算し、最終的に潜在変数にマッピングするための\({\bf{z}}_e\)を計算します。

実装では以下のようになります。

class Encoder(nn.Module):
  def __init__(self, in_channels, hidden_dim, num_residual_layers, residual_hidden_dim, name=None):
    super(Encoder, self).__init__()
    self._in_channels = in_channels
    self._hidden_dim = hidden_dim
    self._num_residual_layers = num_residual_layers
    self._residual_hidden_dim = residual_hidden_dim

    self._enc_1 = nn.Conv2d(in_channels, hidden_dim // 2,
                            kernel_size=4,
                            stride=2,
                            padding=1                            
                            )
    self._enc_2 = nn.Conv2d(hidden_dim //2, hidden_dim,
                            kernel_size=4,
                            stride=2,
                            padding=1
                            )
    self._enc_3 = nn.Conv2d(hidden_dim, hidden_dim,
                            kernel_size=3,
                            stride=1,
                            padding=1
                            )
    self._residual_stack = ResidualStack(
        hidden_dim,
        hidden_dim,
        num_residual_layers,
        residual_hidden_dim,
        )
    
  def forward(self, inputs):
    h = torch.relu(self._enc_1(inputs))
    h = torch.relu(self._enc_2(h))
    h = self._enc_3(h) # ResidualStackの中にReLUが入っている
    return self._residual_stack(h)

デコーダ

次にデコーダです。

デコーダは潜在変数\({\bf{z}}_q\)をインプットとして、residual stack、畳み込み層と処理し、最終的に画像\({\bf{x}}\)を復元しようとします。

最後にシグモイド関数で0-1にしています(他のやり方の方が良いかもしれませんが手っ取り早いので…)。

デコーダの仕組み

実装はこちらです。

class Decoder(nn.Module):
  def __init__(self, in_channels, hidden_dim,  num_residual_layers, residual_hidden_dim,
               name=None):
    super(Decoder, self).__init__()
    self._in_channels = in_channels
    self._hidden_dim = hidden_dim
    self.num_residual_layers = num_residual_layers
    self.residual_hidden_dim = residual_hidden_dim

    self._dec1 = nn.Conv2d(in_channels, hidden_dim,
                           kernel_size=3,
                           stride=1,
                           padding=1)
    self._residual_stack = ResidualStack(hidden_dim, 
                                         hidden_dim,
                                         num_residual_layers,
                                         residual_hidden_dim
                                         )
    self._dec2 = nn.ConvTranspose2d(hidden_dim, hidden_dim // 2,
                                    kernel_size=4,
                                    stride=2,
                                    padding=1
    )
    self._dec3 = nn.ConvTranspose2d(hidden_dim // 2, 3,
                                    kernel_size=4,
                                    stride=2,
                                    padding=1
    )

  def forward(self, encoder_outputs):
    h = self._dec1(encoder_outputs)
    h = self._residual_stack(h)
    h = torch.relu(self._dec2(h))
    x_reconstructed = self._dec3(h)
    x_reconstructed = torch.sigmoid(x_reconstructed)
    return x_reconstructed

離散化

続いて、エンコーダで計算した\({\bf{z}}_e\)から離散化した潜在変数\({\bf{z}}_q\)をマッピングする部分です。

やっていることは以下の図のようにシンプルですが、埋め込みベクトルを選ぶ際には2乗距離が一番近いベクトルを選んでいます。

$$z_q(x)=e_k, \hspace{10pt}\text{where } k=\arg\min_j\|z_e(x)-e_j\|_2$$

潜在変数を離散化する仕組み

コードは以下です。

class VectorQuantizer(nn.Module):
  def __init__(self, embedding_dim, num_embeddings, commitment_cost,
               ):
    super(VectorQuantizer, self).__init__()
    self._embedding_dim = embedding_dim
    self._num_embeddings = num_embeddings
    self._commitment_cost = commitment_cost
    # コードブック(ボトルネック)
    self._w = nn.Embedding(self._num_embeddings, self._embedding_dim)
    self._w.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
    
  def forward(self, inputs):
    '''
    inputs: N×C×H×W
    '''
    # N×C×H×WをN×H×W×Cに変換する. (Cは埋め込みベクトルの次元)
    inputs = inputs.permute(0, 2, 3, 1).contiguous()
    input_shape = inputs.size()
    input_flattened = inputs.view(-1, self._embedding_dim) # すべて縦に並べる
    distances = (torch.sum(input_flattened ** 2, dim=1, keepdim=True) 
                    - 2 * torch.matmul(input_flattened, self._w.weight.t())
                    + torch.sum(self._w.weight ** 2, dim=1))
    encoding_indices = torch.argmax(-distances, 1).unsqueeze(1)
    # one-hotベクトルに変換
    encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
    encodings.scatter_(1, encoding_indices, 1) # one-hot
    # 埋め込み表現を取得し、元のインプットの形に戻す。
    quantized = torch.matmul(encodings, self._w.weight) # one-hot ⇒ ベクトル
    quantized = quantized.view(input_shape) 
    
    # 損失の計算
    # 二乗誤差で計算. sgの部分はdetach()で勾配を計算しないようにする
    e_latent_loss = F.mse_loss(quantized.detach(), inputs)
    q_latent_loss = F.mse_loss(quantized, inputs.detach())
    loss = q_latent_loss + self._commitment_cost * e_latent_loss
    
    # sgの部分はdetach()で勾配を計算しない
    quantized = inputs + (quantized - inputs).detach()
    quantized = quantized.permute(0, 3, 1, 2).contiguous()
    # perplexityを計算 
    avg_probs = torch.mean(encodings, dim=0)
    perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
    return {'distances': distances,
            'quantize': quantized,
            'loss': loss, 
            'encodings': encodings,
            'encoding_indices': encoding_indices,
            'perplexity': perplexity}

slef._wが埋め込みベクトルになります。

input_flattenedとしているところは、\(N\times H\times W\)のすべてのセルを縦に並べて、すべての\({\bf{e}}_j\)との距離を一気に計算しています(賢い計算方法ですね)。

そして、距離が最も近い埋め込みベクトルを選び、潜在変数\({\bf{z}}_q\)を計算しています。

最後に以下の式で損失関数の値を計算しています。

$$L=\log p\left(x|z_q(x)\right) + \left\|\text{sg}[z_e(x)]-e \right \|^2_2 + \beta \left\|z_e(x)-\text{sg}[e]\right\|^2_2$$

VQ-VAE全体

以下でエンコーダ、デコーダ、pre_vq_conv(エンコーダのアウトプットを処理する畳み込み層)、Vector Quantizerを合わせたVQ-VAE全体を作成します。

class VQVAE(nn.Module):
  def __init__(self, encoder, decoder, vqvae, pre_vq_conv1, 
               data_variance, name=None):
    super(VQVAE, self).__init__()
    self._encoder = encoder
    self._decoder = decoder
    self._vqvae = vqvae
    self._pre_vq_conv1 = pre_vq_conv1
    self._data_variance = data_variance
    
  def forward(self, inputs):
    z = self._pre_vq_conv1(self._encoder(inputs)) # zの事前分布
    vq_output = self._vqvae(z)
    x_reconstructed = self._decoder(vq_output['quantize'])
    reconstructed_error = torch.mean(torch.square(x_reconstructed - inputs) / self._data_variance)
    loss = reconstructed_error + vq_output['loss']
    return {
        'z': z,
        'x_reconstructed': x_reconstructed,
        'loss': loss,
        'reconstructed_error': reconstructed_error,
        'vq_output': vq_output
    }

モデルを作成

モデルを作成する関数を定義します。

エンコーダ、デコーダ、pre_vq_conv1、VectorQuantizerをそれぞれインスタンス化し、VQVAEクラスに渡すことでモデルを作成します。

def create_vqvae_model():
  encoder = Encoder(in_channels=3,
                    hidden_dim=HIDDEN_DIM,
                  num_residual_layers=NUM_RESIDUAL_LAYERS,
                  residual_hidden_dim=RESIDUAL_HIDDEN_DIM)

  decoder = Decoder(in_channels=EMBEDDING_DIM,
                    hidden_dim=HIDDEN_DIM,
                num_residual_layers=NUM_RESIDUAL_LAYERS,
                residual_hidden_dim=RESIDUAL_HIDDEN_DIM)

  pre_vq_conv1 = nn.Conv2d(HIDDEN_DIM, 
                           EMBEDDING_DIM,
                           kernel_size=1,
                           stride=1)
                             


  vqvae = VectorQuantizer( # EMAを使わない方
      embedding_dim=EMBEDDING_DIM,
      num_embeddings=NUM_EMBEDDINGS,
      commitment_cost=COMMITMENT_COST,
  )

  # モデル
  model = VQVAE(encoder=encoder,
                decoder=decoder,
                vqvae=vqvae,
                pre_vq_conv1=pre_vq_conv1,
                data_variance=train_data_variance)

  # オプティマイザ
  optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

  return model, optimizer

学習

では、ハイパーパラメータを設定します。

BATCH_SIZE = 32
IMAGE_SIZE = 32

HIDDEN_DIM = 128
RESIDUAL_HIDDEN_DIM = 32
NUM_RESIDUAL_LAYERS = 2

EMBEDDING_DIM = 32 # 各bottle-neckは64次元
NUM_EMBEDDINGS = 128 # いわゆるK. K種類のbottle-neckがある

COMMITMENT_COST = 0.25

USE_EMA = False

DECAY = 0.99
LEARNING_RATE = 3e-4 # 論文では少し小さめの2e-4

エポック数を200にして計算します。

損失等の値は必要に応じて表示してください。

model, optimizer = create_vqvae_model() # モデルの作成
train_res_recon_error, train_res_perplexity = [], []
val_res_recon_error, val_res_perplexity = [], []
device = 'cuda'
num_epochs = 200
model.to(device)
model.train()
for epoch in range(num_epochs):
  for i, data in enumerate(train_loader):
      data = data.to(device)
      results = model(data)

      optimizer.zero_grad()
      loss = results['loss']
      loss.backward()
      optimizer.step()

実行結果

以下のように学習が進んでいきます。

1エポック目終了時点ではモノクロ画像のようになっています。

1エポック終了時

次に10エポック目終了時です。

10エポック終了時

色もだいぶついてきて、どのモンスターかは十分わかります。

続いて、20エポック目終了時です。

20エポック終了時

もう元の画像とあまり変わらないレベルになってきました。

200エポック終了した時点では、非常にきれいな画像になりました。

200エポック終了時

ドラクエI~ドラクエVIIまでの画像をに学習データにして、ドラクエVIIIとドラクエIXの画像をテストデータにしていますが、上記のモンスターは学習データにもありそうです。

なので、ドラクエIXに初めて登場した画像も含めて見てみたところ、以下のようにうまくいっていることが確認できました(ドラクエVII以前の画像も混ざっています)。

複雑な画像もうまく再構築できていることがわかります。

まとめ

今回はPyTorchでVQ-VAEのオートエンコーダ部分を実装しました。

次は、PixelCNNなどの生成モデルで潜在変数の分布を学習して、VQ-VAEにより画像を生成するということをやってみたいと思います。

では!

-AI・機械学習, データ分析実践, 画像認識
-,