【論文解説】Diffusion Modelを理解する

AI・機械学習

GLIDESR3などといったモデルで使われており、最近よく見かけるようになったDiffusion Modelの解説をしたいと思います。

diffusion modelはもともと2015年に『Deep Unsupervised Learning using Nonequilibrium Thermodynamics』という論文で提案されています。

その後、今回紹介する『Denoising Diffusion Probabilistic Models』という論文が2020年に公開され、Denoising Diffusion Modelとして改良されました。

こちらのdiffusion modelはDenoising Diffusion Probabilistic ModelからDDPMと表記されることもあります。

ということで、今回はこちらの論文をメインに解説したいと思います。

Denoising Diffusion Probabilistic Models

理解が浅い部分があるかもしれないので、間違いがあればご指摘ください!!

また、以下の記事では、PyTorchを使ってDiffusionモデルを実装していますので、こちらもあわせて参考にしていただければと思います。

Twitterで新規記事などについて発信し始めたのでフォローしていただけると励みになります!!↓
フォローする

Diffusion Modelの概要

まずは、diffusion modelのざっくりとした概要について説明したいと思います。

forward processとreverse process

diffusion modelは、以下の図のように(1) forward processと(2) reverse processの2つの過程を考えます。

image  of diffusion process

forward processは画像にノイズを加えていって、最終的にはノイズだけになる確率過程です。

一方のreverse processはforward processの逆で、ノイズから画像になっていく確率過程です。

つまり、画像にノイズを加えていって、最終的にノイズのみになる確率過程を考え、その逆をたどることでノイズから画像を生成することができる、というものです。

こちらはさらにざっくりイメージです。

image of process
forward processとreverse processのイメージ. 本来は多次元です.

上記の図で言うと、forward processは\({\bf{x}}_0\)にノイズを加えて\({\bf{x}}_1\)、\({\bf{x}}_2\)、…となり、最終的に\({\bf{x}}_T\)というノイズだけになります。

reverse processはその逆をたどり、ノイズ\({\bf{x}}_T\)から画像\({\bf{x}}_0\)を生成します。

\({\bf{x}}_T\)を平均ゼロ、分散1の標準正規分布とすることで、標準正規分布に従うノイズから画像を生成することができます。

こちらは、『Deep Unsupervised Learning using Nonequilibrium Thermodynamics』からの図ですが、一番上はforward processで\(t=0\)からノイズが加わって、\(t=T\)では完全なノイズになっています。

sample of diffusion process

2段目はreverse processで、\(t=T\)からスタートし、\(t=0\)で多少ノイズは残るものの上段の\(t=0\)の図を再現できています。

また、一番下はノイズに当たります。

VAEとの関連性

観測されるデータは\({\bf{x}}_0\)になり、それ以外の\({\bf{x}}_1, \cdots, {\bf{x}}_T\)は潜在変数と考えます。

ですので、diffusion modelはVAEなどと同じように潜在変数の存在する生成モデルということです。

image of latent variables and data

そのため、目的関数は基本的にはVAEと同様に下界の最大化です。

基本的にはと言ったのは、最終的には下界を非常に単純化して計算される損失関数を最小化していきます。

diffusion modelの仕組み

では、概要を把握したところで、詳細に入っていきたいと思います。

確率過程

forward process

まず、forward process(diffusion processとも言う)、つまり真の値にノイズが加わっていく過程は以下のマルコフ連鎖で定義されます。

$$\begin{align}
q_\theta({\bf{x}}_{1:T}|{\bf{x}}_{0}):&=\prod^T_{t=1}q({\bf{x}}_{t}|{\bf{x}}_{t-1}),\tag{1}\\
q({\bf{x}}_{t}|{\bf{x}}_{t-1})&=N({\bf{x}}_{t}; \sqrt{1-\beta_t }{\bf{x}}_{t-1}, \beta_t {\bf{I}}) \tag{2}
\end{align}$$

上記の式は以下の2点を表しています。

  1. 独立性もしくはマルコフ性(上段の式)
    \({\bf{x}}_{0}\)というデータに対して、\({\bf{x}}_{1}, {\bf{x}}_{2}, \cdots\)という潜在変数が生成されますが、\({\bf{x}}_{t}\)の分布は\({\bf{x}}_{t-1}\)の値にのみ依存します。
  2. 正規分布に従う(下段の式)
    \(q({\bf{x}}_{t}|{\bf{x}}_{t-1})\)は平均\(\sqrt{1-\beta_t {\bf{x}}_{t-1}}\)、分散\(\beta_t {\bf{I}}\)の正規分布に従うということです。

ここで、\(\beta_t\)は各ステップごとに違ってきて、学習パラメータとすることも可能ですが、実験によると定数として決め打ちしてしまっても問題ないことがわかったとのことです。

ですので、forward processについては、学習するパラメータはありません。

reverse process

続いて、reverse processです。

reverseは以下のように定義します。

$$\begin{align}
p_\theta({\bf{x}}_{0:T}):&=p_\theta({\bf{x}}_{T})\prod^T_{t=1}p_\theta({\bf{x}}_{t-1}|{\bf{x}}_{t}), \tag{3}\\
p_\theta({\bf{x}}_{t-1}|{\bf{x}}_{t})&=N\left({\bf{x}}_{t-1}; \mu_\theta({\bf{x}}_t, t), \Sigma_\theta({\bf{x}}_t, t)\right) \tag{4}
\end{align}$$

forward processと同様に、\({\bf{x}}_{t-1}\)の分布は\({\bf{x}}_{t}\)の値にのみ依存し(マルコフ性)、\(p_\theta({\bf{x}}_{t-1}|{\bf{x}}_{t})\)も正規分布に従うということになります。

VAEと同じように、この正規分布の平均・分散をニューラルネットワークで求めます(最終的には分散は固定値にし、平均を求める代わりに誤差を求める問題に置き換わります)。

forward processとreverse processの関係

ここで、forward process、reverse processともに正規分布に従うという仮定を置いていますが、それは良いのでしょうか?

実は、forward processの分散が小さい場合、つまり拡散が小さい場合、その逆の過程であるreverse processも同じ関数形を取ることがわかっています

ですので、\(\beta_t\)を小さくとると、reverse processも同じ正規分布となることが保証されます。

目的関数

目的関数は、対数尤度\(\log p_\theta({\bf{x}}_0)\)を最大化することを考えますが、VAE(Variational Auto-Encoder)で見たように、ニューラルネットワークを使った潜在変数が存在するような複雑なモデルで対数尤度を最大化するのは簡単ではありません。

そこで、VAEと同様に変分推論の考え方を利用し、変分下界(Variational Lower Bound)を最大化します。

下界をざっくり説明すると

潜在変数を\({\bf{z}}\)、観測データを\({\bf{x}}\)とした場合、下界\(L\)は以下で表されます。

$$\begin{align}
L:= \mathbb{E}_{q({{\bf{z}}|\bf{x}})}\left[\log \frac{p({\bf{x}}, {\bf{z}})} { q({\bf{z}}| {\bf{x}})} \right]
\end{align}$$

この下界は対数周辺尤度の下界になっています。

$$\begin{align}
\log p({\bf{x}}) \ge L
\end{align}$$

対数尤度は求まらないけど、対数尤度はこれよりは大きいという下界であれば求められ、その下界を最大化することで対数尤度を最大化しましょう、というものです。

下界の計算方法の詳細についてはここでは省略しますので、こちらのVAEの記事をご参照ください。

では、これをdiffusion processに当てはめましょう。

潜在変数は\({\bf{x}}_{1:T}\)、データは\({\bf{x}}_0\)なので、潜在変数とデータの同時分布は\({\bf{x}}_{0:T}\)で表されますので、

$$\begin{align}
\log p_\theta({\bf{x}}_0)&\ge \mathbb{E}_{q({{\bf{x}}_{1:T}|\bf{x}}_0)}\left[\log \frac{p_\theta({\bf{x}}_{0:T})} { q({\bf{x}}_{1:T}| {\bf{x}}_0)} \right]=:L
\end{align}$$

となります。

ここでは一つのデータ\({\bf{x}}_0\)に対する周辺対数尤度を考えており、最終的にはデータに関して期待値\(\mathbb{E}\)を取る必要があり、論文ではそのような表記になっていますのでご留意ください。

ここで、\((1)\)式と\((3)\)式を代入して変形すると、

$$\begin{align}
L&=\mathbb{E}_{q({{\bf{x}}_{1:T}|\bf{x}}_0)}\left[\log \frac{\prod_{t=1}^T p({\bf{x}}_T) p_\theta({\bf{x}}_{t-1}|{\bf{x}}_{t})} { \prod_{t=1}^T q({\bf{x}}_{t}| {\bf{x}}_{t-1})} \right] \\
&=\mathbb{E}_{q({{\bf{x}}_{1:T}|\bf{x}}_0)}\left[\log p({\bf{x}}_T) + \sum_{t\ge 1}\log \frac{p_\theta({\bf{x}}_{t-1}|p_\theta({\bf{x}}_t))}{q_\theta({\bf{x}}_{t}|q_\theta({\bf{x}}_{t-1}))} \right] \tag{5}
\end{align}$$

と表されます。

単純化した目的関数

これが下界を表す式になるので、これを最大化すればいいのですが、ここからゴリゴリ計算をしたり、単純化することにより非常にすっきりとしたシンプルな式に変形します。

最終形はこのような式になります(計算過程についてはこれから説明します)。

損失関数の最終形

$$L_\text{simple}(\theta)=\mathbb{E}_{{\bf{t}}, {\bf{x}}_0, {\bf{\epsilon}}}\left[\| \epsilon-\epsilon_\theta \left(\sqrt{\bar{\alpha}}{\bf{x}}_0
+\sqrt{1-\bar{\alpha}_t}\epsilon, t \right) \|^2\right] \tag{6}$$

非常にシンプルな式になっていますね。

なお、この式はデータに関して平均を取った形の最終形です。

\(\epsilon_\theta\)がニューラルネットワークになっています。

ざっくりとした解釈としては、各タイムステップ\(t\)、データ\({\bf{x}}_0\)に対してノイズである乱数\(\epsilon\)を振り、ノイズを加えた画像データとタイムステップ\(t\)をインプットとして、\(\epsilon\)を再構築するようなニューラルネットワークのパラメータ\(\theta\)を学習していきます。

アルゴリズム

イメージをつかむために先に学習アルゴリズムを軽く見てみましょう。

\({\bf{x}}_0\)をデータから取ってきて、タイムステップ\(t\)を1から\(T\)から一つ選びます。

そして、標準正規分布に従う\(\epsilon\)を振り、ニューラルネットワークで\(\epsilon_\theta\)を求め、\((6)\)式の勾配を計算してパラメータを更新していきます。

学習が終了しパラメータ\(\theta\)が決まったあとにサンプリングする際には、reverse processを使います。

アルゴリズムは以下のように、\({\bf{x}}_T~N({\bf{0}}, {\bf{I}})\)で生成し、\(T\)から1まで学習したreverse processで\({\bf{x}}_0\)を求めていきます。

目的関数の単純化

では、ここからは一つのデータに対する下界である\((5)\)式から単純化した損失関数\((6)\)式を導出していきたいと思います。

まず、\(q({\bf{x}}_t|{\bf{x}}_{0})\)を求めていきましょう。

これがわかれば、データ\({\bf{x}}_0\)が与えられた場合の\({\bf{x}}_t\)の分布がわかります。

\({\bf{x}}_{t-1}\)が与えられた場合、forward process \((2)\)式から、

$${\bf{x}}_t=\sqrt{1-\beta_t}{\bf{x}}_{t-1}+\sqrt{\beta_{t}}\epsilon_{t} , \hspace{10pt} \epsilon~N({\bf{0}},{\bf{I}})$$

と表されます。

これを繰り返すと、

$$\begin{align}
{\bf{x}}_t&=\sqrt{1-\beta_t}{\bf{x}}_{t-1}+\sqrt{\beta_{t}}\epsilon_{t} \\
&=\sqrt{1-\beta_t}\left(\sqrt{1-\beta_{t-1}}{\bf{x}}_{t-2}+\sqrt{\beta_{t-1}}\epsilon_{t-1}\right)+\sqrt{\beta_{t}}\epsilon_{t}\\
&=\sqrt{1-\beta_t}\sqrt{1-\beta_{t-1}}{\bf{x}}_{t-2}+\sqrt{1-\beta_t}\sqrt{\beta_{t-1}}\epsilon_{t-1}+\sqrt{\beta_{t}}\epsilon_{t}\\
&=…\\
&={\bf{x}}_0\prod_{i=1}^t\sqrt{1-\beta_i}+\sum^t_{i=1}\left(\sqrt{\beta_i}
\prod_{j=i+1}^{t}\sqrt{1-\beta_j}\right)\epsilon_j
\end{align}$$

と表されます。

ここから、平均\(\mu_t\)は、

$$\begin{align}
\mu_t=\prod_{i=1}^t \sqrt{1-\beta_i}{\bf{x}}_0
\end{align}$$

となります。

$$\begin{align}
\alpha_t&=1-\beta_t,\\
\bar{\alpha}_t&=\prod_{s=1}^t\alpha_s
\end{align}$$

とすると、

$$\begin{align}
\mu_t&=\prod_{i=1}^t \sqrt{\alpha_i}=\sqrt{\prod_{i=1}^t\alpha_i}=\sqrt{\bar{\alpha}_t}{\bf{x}}_0
\end{align}$$

となります。

分散については、

$$\begin{align}
\Sigma_t&=\sum_{i=1}^t\left(\sqrt{\beta_i}\prod_{j=i+1}^t\sqrt{1-\beta_j}\right)^2\\
&=\sum_{i=1}^t\left(\beta_i\prod_{j=i+1}^t(1-\beta_j)\right)\\
&=\sum_{i=1}^t\left((1-\alpha_i)\prod_{j=i+1}^t\alpha_j\right)\\
&=\sum_{i=1}^t\left(\prod^t_{j=i+1}\alpha_{j}-\prod^t_{j=i}\alpha_j \right)\\
&=\sum_{i=1}^t\left(\alpha_{i+1, t}-\alpha_{i,t}\right)\\
&=\alpha_{t+1,t}-\alpha_{1,t}\\
&=1-\prod^t_{i=1}\alpha_i\\
&=1-\bar{\alpha}_t
\end{align}$$

と表すことができます。

変形の途中では、

$$\alpha_{i,t}=\prod^t_{j=1}\alpha_j$$

としています。

以上より、

$$q({\bf{x}}_t|{\bf{x}}_0)=N\left({\bf{x}}_t; \sqrt{\bar{\alpha}_t}{\bf{x}}_0, (1-\bar{\alpha}_t){\bf{I}}\right) \tag{7}$$

が求まります。

この式により、データ\({\bf{x}}_0\)に対し、ステップ\(t\)における\({\bf{x}}_t\)を求めることができます。

この式はあとで出てきます。

次に、目的関数\((5)\)式を変形していきます。

\((5)\)式で\(t=0\)のときだけ取り出すと

$$\begin{align}
\text{(5)式} &= \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log p({\bf{x}}_T)+\sum_{t>1}\log\frac{p_\theta({\bf{x}}_{t-1}|{\bf{x}}_t)} {q({\bf{x}}_{t}|{\bf{x}}_{t-1})}
+ \log\frac{p_\theta({\bf{x}}_{0}|{\bf{x}}_1)} {q({\bf{x}}_{1}|{\bf{x}}_{0})} \right]
\end{align}$$

となります。

ここで、

$$\begin{align}
q({\bf{x}}_{t}|{\bf{x}}_{t-1})&=q({\bf{x}}_{t}|{\bf{x}}_{t-1}, {\bf{x}}_{0})\\
&=\frac{q({\bf{x}}_{t}, {\bf{x}}_{t-1}|{\bf{x}}_{0})}{q({\bf{x}}_{t-1}|{\bf{x}}_{0})}\\
&=q({\bf{x}}_{t-1}|{\bf{x}}_{t},{\bf{x}}_{0})\frac{q({\bf{x}}_{t}|{\bf{x}}_{0})}{q({\bf{x}}_{t-1}|{\bf{x}}_{0})}
\end{align}$$

という関係を使って変形すると、

$$\begin{align}
\text{(5)式} &= \mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log p({\bf{x}}_T)+
\sum_{t>1}\log\left(\frac{p_\theta({\bf{x}}_{t-1}|{\bf{x}}_t)} {q({\bf{x}}_{t-1}|{\bf{x}}_{t}, {\bf{x}}_{0})}
\frac{q({\bf{x}}_{t-1}|{\bf{x}}_{0})}{q({\bf{x}}_{t}|{\bf{x}}_{0})}\right)
+ \log\frac{p_\theta({\bf{x}}_{0}|{\bf{x}}_1)} {q({\bf{x}}_{1}|{\bf{x}}_{0})} \right] \\
&=\mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log p({\bf{x}}_T)+
\sum_{t>1}\log\frac{p_\theta({\bf{x}}_{t-1}|{\bf{x}}_t)} {q({\bf{x}}_{t-1}|{\bf{x}}_{t}, {\bf{x}}_{0})}
+\sum^T_{t>1} \left(\log q({\bf{x}}_{t-1}|{\bf{x}}_{0}) -\log q({\bf{x}}_{t}|{\bf{x}}_{0})\right)
+ \log\frac{p_\theta({\bf{x}}_{0}|{\bf{x}}_1)} {q({\bf{x}}_{1}|{\bf{x}}_{0})} \right] \\
&=\mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log\frac{ p({\bf{x}}_T)}{q({\bf{x}}_T|{\bf{x}}_0)}+
\sum_{t>1}\log\frac{p_\theta({\bf{x}}_{t-1}|{\bf{x}}_t)} {q({\bf{x}}_{t-1}|{\bf{x}}_{t}, {\bf{x}}_{0})}
+ \log p_\theta({\bf{x}}_0|{\bf{x}}_1) \right] \\
&=D_\text{KL}\left(q({\bf{x}}_T|{\bf{x}}_0)p({\bf{x}}_T)\right)
+\sum_{t>1}D_\text{KL}\left(q({\bf{x}}_{t-1}|{\bf{x}}_{t}, {\bf{x}}_{0})\| p_\theta({\bf{x}}_{t-1}|{\bf{x}}_{0})\right) \\
&\hspace{2pt}+\mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log p_\theta({\bf{x}}_0|{\bf{x}}_1)\right] \tag{8}
\end{align}$$

と変形できます。

ここで、3行目は

$$\sum^T_{t>1} \left(\log q({\bf{x}}_{t-1}|{\bf{x}}_{0}) -\log q({\bf{x}}_{t}|{\bf{x}}_{0})\right)-\log q({\bf{x}}_1|{\bf{x}}_0)=-\log q({\bf{x}}_T|{\bf{x}}_0)$$

となることを使っています。

なお、カルバック・ライブラー・ダイバージェンスは以下で表されますのでご参考まで。

$$D_\text{KL}(q|p)=\mathbb{E}_q\left[\log\frac{p(X)}{q(X)}\right]$$

そして、論文では上記の3つの項に分けています。

$$\begin{align}
L_T&=D_\text{KL}\left(q({\bf{x}}_T|{\bf{x}}_0)\|p({\bf{x}}_T)\right)\\
L_{t-1}&=D_\text{KL}\left(q({\bf{x}}_{t-1}|{\bf{x}}_{t}, {\bf{x}}_{0})\| p_\theta({\bf{x}}_{t-1}|{\bf{x}}_{t})\right) \\
L_0&=\mathbb{E}_{q({\bf{x}}_{1:T}|{\bf{x}}_0)}\left[\log p_\theta({\bf{x}}_0|{\bf{x}}_1)\right]
\end{align}$$

こう置くことにより、

$$(5)式=L_T+\sum_{t>1}L_{t-1}+L_0 \tag{9}$$

と表されます。

\(L_T\)については、\(\theta\)が出てきていないのでニューラルネットワークは関係ありません。

そして、\(\beta_t\)を固定値とすると、パラメータがないので\(L_T\)については無視してよいとことになります

ですので、\(L_{1:T-1}\)と\(L_0\)について見ていきましょう。

\(L_{1:T-1}\)

上述の通り、\(L_{t-1}\)は以下のように表されます。

$$L_{t-1}=D_\text{KL}\left(q({\bf{x}}_{t-1}|{\bf{x}}_{t}, {\bf{x}}_{0})\| p_\theta({\bf{x}}_{t-1}|{\bf{x}}_{t})\right) $$

\(q({\bf{x}}_{t-1}|{\bf{x}}_{t}, {\bf{x}}_{0})\)と\(p_\theta({\bf{x}}_{t-1}|{\bf{x}}_{t})\)のカルバック・ライブラー・ダイバージェンスです。

1st step

ここで、\(p_\theta({\bf{x}}_{t-1}|{\bf{x}}_{t})\)はreverse processの確率過程で、

$$p_\theta({\bf{x}}_{t-1}|{\bf{x}}_{t})=N\left({\bf{x}}_{t-1}; \mu_\theta({\bf{x}}_t, t), \Sigma_\theta({\bf{x}}_t, t)\right)$$

と定義されていました。

\(\mu_\theta\)と\(\Sigma_\theta\)はニューラルネットワークで求める必要がありますが、ここで\(\Sigma_\theta\)をニューラルネットワークで計算しない\(\sigma_t^2{\bf{I}}=\beta_t{\bf{I}}\)として、簡略化します。

つまり、forward processと同じ分散としています。

実験の結果では、この仮定をおいてもパフォーマンスへの影響はあまりなかったようです。

2nd step

つづいて、あとで使うために\(q({\bf{x}}_{t-1}|{\bf{x}}_{t}, {\bf{x}}_{0})\)を計算しておきます。

こちらは以下のように計算できます。

$$\begin{align}
q({\bf{x}}_{t-1}|{\bf{x}}_{t}, {\bf{x}}_{0})&=N\left({\bf{x}}_{t-1}; \tilde{\mu}_t({\bf{x}}_t, {\bf{x}}_0), \tilde{\beta}_t{\bf{I}}\right)
\end{align}$$

ここで、\(\tilde{\mu}_t\)と\(\tilde{\beta}_t\)は次で表されます。

$$\begin{align}
\tilde{\mu}_t({\bf{x}}_t, {\bf{x}}_0)&=\frac{\sqrt{\alpha_{t-1}\beta_t}}{1-\bar{\alpha}_t}{\bf{x}}_0+\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}{\bf{x}}_t \tag{10}\\
\tilde{\beta}_t &= \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t \tag{11}
\end{align}$$

ちょっと大変そうなのでまだ計算していませんが、どこかのタイミングで確認して投稿したいと思います。

3rd step

この2つの式を使って、\(L_{t-1}\)を変形します。

まず、\(p(x)~N(\mu_p, \sigma^2_p)\)、\(p(q)~N(\mu_q, \sigma^2_q)\)の場合にその2つの分布のKLダイバージェンスは、

$$\begin{align}
D_\text{KL}\left(q\|p\right)&=\int p(x)\log\frac{p(x)}{q(x)}dx \\
&=\log \frac{\sigma_q}{\sigma_p} +\frac{\sigma_p^2+(\mu_p-\mu_q)^2}{2\sigma_q^2}-\frac{1}{2} \\
&= \frac{(\mu_p-\mu_q)^2}{2\sigma_q^2}+C \tag{12}
\end{align}$$

と計算できることを使います。

\(C\)は\(\mu_p\)を含まない項がまとめられています。

この計算はあとで載せておきますが(初回の投稿では載せていないかもしれません)、頑張って計算すれば求まります。

そして、それぞれ\(\mu_p=\mu_\theta({\bf{x}}_t, t)\)、\(\sigma_p^2=\beta_t\)、\(\mu_q=\tilde{\mu}({\bf{x}}_t, {\bf{x}}_0)\)、\(\sigma_q^2=\tilde{\beta}_t\)なので、

$$\begin{align}
L_{t-1}&=D_\text{KL}\left(q({\bf{x}}_{t-1}|{\bf{x}}_{t}, {\bf{x}}_{0})\| p_\theta({\bf{x}}_{t-1}|{\bf{x}}_{t})\right) \\
&= \frac{1}{2\sigma^2_t}\| \tilde{\mu}({\bf{x}}_t, {\bf{x}}_0) – \mu_\theta({\bf{x}}_t, t)\|^2 + C \tag{13}
\end{align}$$

となります。

上記のカルバック・ライブラー・ダイバージェンスは、二つの正規分布の平均パラメータの2乗誤差に比例することがわかります。

すっきりしてきましたね。

4th step

上記の\(L_{t-1}\)を使って計算しても良いのですが、さらに単純化を行います。

前に求めた\((7)\)式は、

$$q({\bf{x}}_t|{\bf{x}}_0)=N\left(\sqrt{\bar{\alpha}_t}{\bf{x}}_0, (1-\bar{\alpha}_t){\bf{I}}\right)$$

でした。

ここから、

$${\bf{x}}_t=\sqrt{\bar{\alpha}_t}{\bf{x}}_0 + \sqrt{1-\bar{\alpha}_t}\epsilon$$

と表されます。

なので、\((13)\)式の\({\bf{x}}_0\)を\({\bf{x}}_t\)で表すようにすると、

$$L_{t-1}-C= \frac{1}{2\sigma^2_t}\left\| \tilde{\mu}\left({\bf{x}}_t, \frac{1}{\sqrt{\bar{\alpha}_t}}\left({\bf{x}}_t({\bf{x}}_0, \epsilon)-\sqrt{1-\bar{\alpha}_t}\epsilon\right)\right) – \mu_\theta({\bf{x}}_t, t)\right\|^2 $$

となります。

\({\bf{x}}_t\)は\({\bf{x}}_0\)と\(\epsilon\)の関数になっています。

ここで、\(\tilde{\mu}\)に\((10)\)式で表されるので、\(10\)式を使って変形すると、

$$L_{t-1}-C= \frac{1}{2\sigma^2_t}\left\| \frac{1}{\sqrt{\alpha_t}}\left({\bf{x}}_t({\bf{x}}_0, \epsilon)-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon\right)- \mu_\theta({\bf{x}}_t, t)\right\|^2 \tag{14}$$

となります。

大変ではありますが、代入して計算するだけですので、途中式は記載しませんが(これもいつか掲載するかもしれません)、

$$\beta_t=1-\alpha_t$$

$$\bar{\alpha}_t=\prod_{s=1}^t\alpha_s \hspace{10pt} \Rightarrow \hspace{10pt} \frac{\bar{\alpha}_{t-1}}{\bar{\alpha}_{t}}=\frac{1}{\alpha_t}$$

といった関係を使えば求まります。

5th step

\((14)\)式より、\(\mu_\theta\)は\(\tilde{\mu}_t({\bf{x}}_t, \epsilon)\)を変形した

$$\frac{1}{\sqrt{\alpha_t}}\left({\bf{x}}_t({\bf{x}}_0, \epsilon)-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon\right)$$

を予測すればよいことになります。

ここから、関数\(\mu_\theta\)が\((15)\)に近くなるためには、\({\bf{x}}_t\)がインプットなので既知であることから、

$$\mu_\theta({\bf{x}}_t, t)=\frac{1}{\sqrt{\alpha_t}}\left({\bf{x}}_t({\bf{x}}_0, \epsilon)-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta({\bf{x}}_t, t))\right) \tag{16}$$

とすればよいと考えられます。

\(\epsilon_\theta({\bf{x}}_t, t)\)とすることで\(\epsilon\)を\({\bf{x}}_t\)と\(t\)をインプットとするパラメータが\(\theta\)の関数とします。

これで\(\mu_\theta\)が求まり、分散は\(\sigma^2_t\)と置いているので、reverse processの分布\(p({\bf{x}}_{t-1}|{\bf{x}}_{t})\)が求まりました。

ですので、\({\bf{x}}_t\)が与えられたときの\({\bf{x}}\)は

$${\bf{x}}_{t-1}=\mu_\theta({\bf{x}}_{t},t) + \sigma_t {\bf{z}}, \hspace{10pt} z\sim N({\bf{0}}, {\bf{I}})$$

でサンプリングすることができます。

そして、\((15)\)式を\((14)\)式に代入して整理すると、

$$\frac{\beta_t^2}{2\sigma^2_t\alpha_t(1-\bar{\alpha}_t)}\left\| \epsilon – \epsilon_\theta\left( \sqrt{\bar{\alpha}_t}{\bf{x}}_0+\sqrt{1-\bar{\alpha}_t}\epsilon,t \right) \right\|^2 \tag{16}$$

と単純化することができます。

\((14)\)式は\(\tilde{\mu}\)をニューラルネットワークで予測する問題でしたが、\((16)\)式では\(\epsilon\)を予測する問題になりました。

どちらの方法でも良いのですが、あとで行う実験では\(\epsilon\)を予測する後者の問題を解く方が精度が良かったということです。

あとで、この式をもう少しだけ単純化し、それがさらに良い精度になったということです。

\(L_0\)

次に\(L_0\)を見ていきます。

\(L_0\)は

$$L_0=\mathbb{E}_{q({\bf{x}}_{1:T})|{\bf{x}}_0}\left[\log p_\theta({\bf{x}}_0|{\bf{x}}_1)\right]$$

です。

\({\bf{x}}_1\)から\({\bf{x}}_0\)を生成する、つまり潜在変数から画像を生成する部分に当たるのでデコーダの項になります。

デコーダですので、PixelCNNなどをデコーダとして使うことも一つですが、ここではシンプルに正規分布の確率計算にしています。

インプットの\(\{0, \cdots, 255\}\)を\([-1,1]\)にスケーリングしているものと仮定しておきます。

$$\begin{align}
\log p_\theta({\bf{x}}_0|{\bf{x}}_1)&=\prod^D_{i=1}\int^{\delta_+(x_0^i)}_{\delta_-(x_0^i)}N(x;\mu_\theta^i({\bf{x}}_1,1),\sigma_1^2)dx
\end{align}$$

ここで、

$$\begin{align}
\delta_+(x_0^i)&=\left\{ \begin{array}{l, l}\infty & \hspace{10pt}\text{if }x=1\\ x+\frac{1}{255} & \hspace{10pt}\text{if }x<1 \end{array}\right.,
\delta_-(x_0^i)&=\left\{ \begin{array}{l, l}-\infty & \hspace{10pt}\text{if }x=-1\\ x-\frac{1}{255} & \hspace{10pt}\text{if }x>-1 \end{array}\right.
\end{align}$$

です。

\(D\)は\({\bf{x}}\)の次元です。

\(t=1\)なので正規分布の平均は\(\mu_\theta({\bf{x}}_1, 1)\)になっています。

これは単に分布関数\(N\)を使って\((x-\frac{1}{255},x+\frac{1}{255})\)の区間で面積を計算し、(同時)確率を求めているだけです。

いずれにせよ、さらに単純化した最終形ではこちらは使いません。

さらに単純化

以下の\((16)\)式は、\(\epsilon\)をニューラルネットワークで予測した際の誤差をタイムステップに応じて重みづけしていることになります。

$$\frac{\beta_t^2}{2\sigma^2_t\alpha_t(1-\bar{\alpha}_t)}\left\| \epsilon – \epsilon_\theta\left( \sqrt{\bar{\alpha}_t}{\bf{x}}_0+\sqrt{1-\bar{\alpha}_t}\epsilon,t \right) \right\|^2$$

論文では、重みづけをしない形にして、

$$\left\| \epsilon – \epsilon_\theta\left( \sqrt{\bar{\alpha}_t}{\bf{x}}_0+\sqrt{1-\bar{\alpha}_t}\epsilon,t \right) \right\|^2$$

と単純化します。

実験の結果では、このように単純化した方が精度が改善していますので、あとで確認したいと思います。

そして、最終的な損失関数は、タイムステップ\(t\)、データ\({\bf{x}}_0\)、乱数\(\epsilon\)で平均を取ります。

$$L_\text{simple}(\theta)=\mathbb{E}_{t,{\bf{x}}_0, \epsilon}\left[\left\| \epsilon – \epsilon_\theta\left( \sqrt{\bar{\alpha}_t}{\bf{x}}_0+\sqrt{1-\bar{\alpha}_t}\epsilon,t \right) \right\|^2\right] \tag{17}$$

以上で、この論文において最も精度が良かった損失関数を求めることができました。

\(t=1\)の場合が\(L_0\)に対応しますが、この単純化した損失関数では\(L_0\)の議論はまったく出てこず、\(t>1\)の場合と同様に計算します。

実験

タイムステップは\(T=1000\)とます。

ニューラルネットワークの部分はU-Netを使用します。

U-Netの論文はこちらです。

U-Net: Convolutional Networks for Biomedical Image Segmentation

U-Netの詳しい説明は割愛しますが、上図のようにインプットを小さいサイズに圧縮し戻していくモデルです。

その際に同じサイズ同士でスキップ・コネクションにより連結するという特徴があります。

U-Netについては、こちらの記事で実装していますので参考にしてみてください ⇒ 『DiffusionモデルをPyTorchで実装する② ~ U-Net編』

(CouseraのGANコースでも解説されています。 ⇒ 【Cousera講座レビュー】DeepLearning.ai『Generative Adversarial Networks(GANs)』)

さらにTransformerで使われていたsin関数を使ったposition embeddingを使用し、self-attentionを使ったモデルとします。

そして、\(\beta_1=10^{-4}\)、\(\beta_T=0.02\)としています。

この\(\beta\)を十分に小さくすることで、forward processとreverse processが同じ関数形で近似できます。

結果

では実験結果を見ていきましょう。

CIFAR10での精度

CIFAR10データセットを使った結果です。

\(L_\text{simple}\)を損失関数とした場合において、IS(Inceptionスコア)はStyleGANには及びませんが、FIDに関してはStyleGANを上回っており、非常に良い結果となっています

ちなみに、Inceptionスコアは高ければ良い指標、FIDは低ければ良い指標になります。

それぞれの詳細は『Inception Scoreを理解する』と『Frechet Inception Distanceを理解する』をご参照ください。

ラベルを指定して生成するConditionalと比較しても一番ではありませんが悪くない結果です。

生成された画像

以下はLSUNデータセットで学習して、生成された画像です。

非常にキレイですね。

損失関数の比較

続いて、損失関数を単純化する影響について見ていきます。

結果は以下のようになっています。

まず、\(\mu\)を予測した場合だと不安定だったため結果が記載されていません。

一方で\(\epsilon\)を予測した場合の一番単純化した式\((17)\)式を使った場合が一番精度が良くなっていることがわかります

他にも\(\Sigma\)を学習させて場合もありますが、学習させない場合の方が結果が良くなっていることがわかります。

画像の生成過程

次にreverse processにおいて生成される過程\({\bf{x}}_t\)を見てみましょう。

以下のように、初めはランダムなノイズですが(一番左)、タイムステップが進むごとにだんだんとぼんやりした全体像が見えてきて、最後には詳細が鮮明になってきています。

潜在変数の共有

次に、同じ\({\bf{x}}_t\)から生成される画像を比較します。

Share x1000とある一番左の4つの画像がありますが、これはその4つのうち右下の\({\bf{x}}_{1000}\)を共有して生成された画像です。

ですので、同じノイズから生成されているので、まったく違う画像が生成されています。

一方でShare x750とある左から二番目の画像は同じ\({\bf{x}}_{750}\)から生成された画像です。

若干違うものの、すべてサングラスをかけているということが共通しています

さらに500、250となるにつれ、同じような画像が生成されます。

潜在変数の補間

他にも以下のようにそれぞれのステップで潜在変数を画像を補間した場合に生成される画像を見ることができます。

ステップが1000に近いほどバラバラの画像になり、小さいステップで補間するほど具体的になり、あまり小さいステップだと無理やり補間してしまい、不自然な画像になってしまいます。

まとめ

今回はDiffusionモデルを見てきました。

まだ計算過程を記載しきれていない箇所や(解くのも大変ですが、式を書いていくのはさらに大変なので…)、確認しきれていないところがあるので、その辺りはおいおい記載していきたいと思います。

では、次はGLIDEやSR3などDiffusionモデルを見ていきたいと思います!!

mm_0824

システム開発会社や銀行・証券会社で統計や金融工学を使ったクオンツ・分析業務を長く担当してきました。

現在はコンサルティング会社のデータ・サイエンティストとして機械学習、自然言語処理技術を使ったモデル構築・データ分析を担当しています。

このサイトでは論文や本、Udemyなどの学習ツールについての情報発信をしていきます。

皆様の業務や勉強のお役に立てれば嬉しいです。

Twitterでも新規記事についての発信をし始めたので是非フォローしていただけると嬉しいです!↓

フォローする
AI・機械学習 画像認識
フォローする
楽しみながら理解するAI・機械学習入門

コメント

タイトルとURLをコピーしました