【論文解説】BARTを理解する

AI・機械学習

今回は、『BART(Bidirectional Auto-Regressive Transformer)』を解説したいと思います。

簡単に言うと、BARTはBERTをSequence-to-Sequence(Seq2Seq)の形にしたものです。

ですので、モデルの仕組みは当初のTransformer論文で提案された形に近くなっています。

「Attention is All You Need」より

このSeq2Seqの仕組みにより、機械翻訳(Machine Translation)や文書の要約(Document Summarization)にも適用することが可能です。

そして、RoBERTaと同じデータセットで学習することで、分類タスクの精度はRoBERTaと同程度、文章生成系のタスクでは過去のモデルをアウトパフォームするという結果が出ています。

では、詳細を見ていきましょう。

論文はこちらです。

BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension

BARTの仕組み

BARTの仕組みはBERTの双方向TransformerとGPTの片方向(Auto-regressive)Transformerを組み合わせた形です。

BERTをエンコーダとしてインプットを処理し、インプットをエンコーディングします。

そして、そのエンコーディングした情報を使ってGPTのauto-regressiveなデコーダにより文書を生成します。

まず、こちらがBERTの簡略図です。

双方向のTransformerで、単語をランダムにマスクし(下図のB, D)、それを予測して事前学習を行います(Masked Language Modelと呼ばれます)。

そして、こちらがGPTなどのAuto-regressiveなモデルです。

次の単語、次の単語を予測する形で事前学習を行います(Language Modelと呼ばれます)。

そして、BARTはそれらを組み合わてSeq2Seq(Sequence-to-Sequence)にしたもので、以下のように表されます。

この仕組みにより嬉しい点は以下です。

  • 翻訳や文書の要約などの文章生成がうまくいく。
  • マスク付き言語モデルの事前学習以外に、より汎用的な学習方法を使用できる。

1つめはSeq2Seqの形を取っているので明らかだと思います。

質問に対して、答えを生成するということも可能ですし、文章要約も可能になります。

2つめは、この論文で深く分析されている点ですが、複数の単語の列を一つの[MASK]に置き換え、任意の長さの文字を予測するといったより難しい事前学習も可能になります。

モデルの構造

モデルは『Attention is All You Need』で提案されたTransformerをほぼ同じなので、細かいところは触れられていません。

Transformerの論文と違うところは、活性化関数をReLUではなく、GPTやBERTのようにGeLUにしている点です。

GeLUの詳細はこちらをご覧ください。

エンコーダー、デコーダーは、ベースサイズでそれぞれ6層のTransformerブロックで構成され、ラージサイズではそれぞれ12層のTransformerブロックとなっています。

BARTの事前学習

BARTは文章を生成することが可能なので、インプットとアウトプットの長さを変えたりすることが可能です。

ですので、通常のBERTよりも事前学習の方法に自由度があります。

そこで、本論文ではいくつかの事前学習方法を試し、比較しています。

  • Token Masking
    BERTと同じで、ランダムに単語をマスキングし、その部分を予測するというものです。
  • Token Deletion
    ランダムに単語を削除し、その単語を埋めた文章を生成することを学習します。
    どの単語が抜けているかわからないので、それも予測しないとならず、Token Maskingと比べて難易度が上がると考えられます。
  • Text Infilling
    複数の単語の並びを一つの[MASK]で置き換えます。
    置き換える単語数は、\(\lambda=3\)のポアソン分布に従う乱数を発生させて決めます。ですので平均3個(分散も3)の連続した単語が[MASK]に置き換えられることになります。
  • Sentence Permutation
    複数の文章からなる文書について、文章の順番をシャッフルします。
  • Document Rotation
    文章から単語を一つ選び、その単語が一番初めになるように、文章を回転させます。
    「私 は 犬 を 飼って いる」であれば、「犬 を 飼って いる 私 は」というような具合です。
    どの単語が初めの単語か?を学習します。

ファインチューニング

ファインチューニングの考え方自体に特別な点はありませんが、エンコーダー・デコーダーの形なので少し特殊な処理が必要な部分もあります。

タスクに応じたインプットと分類・生成方法を見ていきましょう。

  • 文章分類タスク
    分類したい文章をエンコーダー、デコーダーの両方に入力します。
    分類に使う層はデコーダーの最終層の最後の単語部分になります。
    BERTでは文章の最初に分類用の[CLS]トークンを付けていましたが、BARTでは文章の最後に分類用のトークンを追加するとのことです。
  • 単語分類タスク
    SQuADのように、どの単語(列)が答えかを答えるタスクです。
    質問文をつなげたものをエンコーダー、デコーダー両方のインプットとします。そして、デコーダーの最終層の各単語部分の隠れ層の値を分類用に使います。
  • 文章生成タスク
    文章要約などのタスクです。
    エンコーダーには要約したい文章を、デコーダーは自分自身が生成した文章をインプットとしauto-regressiveに処理を行います。
  • 機械翻訳
    外国語を英語に翻訳するタスクを行います。
    この際にBARTでは、単純に外国語をエンコーダーのインプットとするだけではなく、エンコーダーの埋め込み層を別の層に変えます(source encoder)。
    この層が外国語を英語にマッピングしているようなイメージです。
    そしてまず、source encoderとpositional embedding層、self-attentionの初めに掛けるprojection行列のみを更新し、それ以外は更新しません。その後、すべてのパラメータを更新します。

学習方法・実験

前にマスクの方法を変えた事前学習方法を説明しましたが、論文では、さらに5つの学習方法を試しています。

  • Language Model
    こちらは通常の言語モデルの学習方法です。
    GPTのように文章を左から右に、次の単語次の単語を予測することで学習していきます。
  • Permutated Language Model
    XLNetのような形で文章中の単語の順番を変えます。
    ここでは、文章中の1/6の単語の順番を変えて、それを予測します。
  • Masked Language Model
    BERTの学習方法です。文章中の15%の単語を[MASK]に置き換え、その部分の単語を予測します。
  • Multitask Masked Language Model
    GPTのleft-to-rightのマスク、BERTのランダムなマスクに、さらにマスクを追加します。
    1/6のleft-to-right、1/6のright-to-left、1/3はマスクしない、残り1/3については、初めの50%(半分)はマスクせず、のこりをleft-to-rightのマスクをします。
  • Masked Seq-to-Seq
    文章中の単語の50%の範囲をマスクし、Seq-to-Seqでマスク部分を予測します。

タスク

  • 続いて、実際に解いていくタスクを説明します。
  • SQuAD
    Wikipediaの記事について、質問に対する答えの部分を抜き出す形で答えるタスクです。
    インプットは、質問文と記事を[SEP]で区切ったものです。
    これをエンコーダーとデコーダーの両方に入力し、デコーダーの各単語位置の(最後の)隠れ層の値を使って、分類層で開始位置、終了位置かどうかを分類します。
  • MNLI
    2つの文章の含意関係(含意、矛盾、中立)を予測します。
    2つの文章を[SEP]でつなぎ、最後に[EOS]を付加したものをエンコーダー、デコーダー両方のインプットとします。
    そして、デコーダーの[EOS]部分に対応する隠れ層の値を分類に使用します。
  • ELI5
    質問に対する答えを長文の文書をもとに答えます。
    データセットの例を見る限り非常に難しいタスクのようです。
    私には解けません(笑)

  • インプットは質問とドキュメントで、デコーダーが文章を生成します。
  • XSum
    ニュースの要約をするデータセットです。
  • ConvAI2
    対話文を生成するデータセットです。
    インプットは対話の内容ペルソナ(何が好きか、何をやっているかなど)です。
  • CNN/DM
    こちらも要約のためのデータセットです。

結果

では、上記のタスクについて、さまざまな事前学習の方法を試した結果を、論文の解説に沿って見ていきましょう。

まず、全体の結果は以下のようになっています。

すべてのデータセットに共通して良い手法は存在しない

表の上段ですが、ELI5ではLanguage Modelによる事前学習方法が一番良い結果となっていますが、MNLIデータセットでは精度が一番悪くなっています。

単語のマスキングは必要

表の下段を見ると、Document RotationやSentence Shufflingは他の方法と比べて精度が著しく悪化しています。

Token MaskingやToken Deletion、Text Infillingなど単語や単語列を[MASK]に置き換える手法の方が良くなっています。

やはり、単語の順番のみを予測するだけでは言語の学習にはならず、どんな単語が入るか?を予測することで学習できるということのようですね。

Left-to-RightのLanguage Modelの事前学習は文章生成タスクで有効。

SQuADやMNLIといったタスクではLanguage ModelよりもMasked Language ModelやPermutated Language Modelの方が精度が高くなっていますが、それ以外の文章生成するタスクではLanguage Modelの方がperplexityが低くなっています

SQuADには双方向のエンコーダーが必要。

一方で、Language ModelはSQuADの精度が非常に低くなっています。

ですので、質疑応答のような文章を最後まで読む必要があるようなタスクにおいては、双方向のattentionが必要なことがわかります。

ELI5データセットを除くとBARTの精度が高い。

ELI5データセットではLanguage Modelを事前学習したモデルのperpleixtyが一番良くなっていますが、それ以外のデータセットではすべてBARTが他の手法を上回っています。

大きなモデルを学習

RoBERTaやGPTなどのように大きなモデルを学習することで、精度がどのようになるかを見ていきます。

設定

レイヤー数をエンコーダー、デコーダーそれぞれ12層とし、隠れ層の次元は1024次元とします。

バッチサイズは8,000で500,000ステップ学習させます。

事前学習の方法はtext infillingとsentence permutationを使います。

つまり、1単語ではなく連続した複数単語を一つの[MASK]で置き換え、そして文章をシャッフルします。

学習に使うデータセットはRoBERTaと同じものを使い合計160GBになります。

詳細はこちらをご参照ください。

結果

分類タスク

こちらは分類タスクの結果です。

データセットによって多少違いますが、RoBERTaとほぼ同程度の精度です。

RoBERTaと同じ事前学習データを使っているので、これらのタスクに対しては同じ結果と言ってよいと思います。

CoLAデータセット(文法が合っているか間違っているかを当てるタスク)では、BARTが良くないのが少し気になります。

生成タスク

精度が改善することが期待される生成タスクのうち、文章要約の結果です。

R1、R2、RLはRougeという指標ですが、Rougeについてはこちらをご参照ください。

簡単にいうと、答えと予測で単語の重複がどれだけあるか、などを指標としているものです。

まず、左側のCNN/DailyMailデータセットの特徴は、要約が文書中の文章と似ているということで、抽出型のモデルがうまくいくようです。

一番上のLead-3というのが、最初の3つの文章を要約とした場合ですが、それでもそこそこうまくいっています。

しかしながらBARTは、現状のモデルを2ポイント程度上回っていることがわかります。

XSumデータセットはCNN/DailyMailデータセットよりも抽象度が高く、抽出型のモデルではうまくいきません。

実際、Lead-3だとかなり精度が低くなっています。

他にもデータセットがあるので興味がある方は論文を確認していただければと思います。

追加の分析

要約タスクで精度が改善していることはわかりましたが、ここでは実際のサンプルでその精度を見ていきます。

以下は、WikiNewsのテストサンプルです。

最後の例だと、元の文章は以下です。

PG&Eは、乾燥した状態での強風が予測されているため、停電を予定していると述べました。目的は山火事のリスクを減らすためとのことです。少なくとも明日の正午まで続くと見られ、約80万人の顧客がその停電の影響を受けます。

要約はこのようになっています(訳が下手ですみません)。

カリフォルニアでの計画停電により、数百万の顧客の電源が落とされます。

若干違うところもありそうですが、文法的には正しく、PG&E社がカリフォルニアの電力会社であるという知識も使って要約文を生成しています。

まとめ

今回は、BERTをSeq2Seqの形にした「BART」を見てきました。

Seq2Seqにすることで色々な事前学習方法をが試せるようになっています。

そして、やはり文書要約やabstractive Question Answeringなどのタスクでは非常に精度が良くなっています。

論文中には他にも実験結果が載っているので、興味のある方はご覧いただければと思います。

では!

mm0824

システム開発会社や金融機関で統計や金融工学を使ったモデリング・分析業務を長く担当してきました。

現在はコンサルティング会社のデータ・サイエンティストとして機械学習、自然言語処理技術を使ったモデル構築・データ分析を担当しています。

皆様の業務や勉強のお役に立てれば嬉しいです。

mm0824をフォローする
AI・機械学習 自然言語処理
mm0824をフォローする
楽しみながら理解する自然言語処理入門

コメント

タイトルとURLをコピーしました