RNNの訓練手法の比較

RNNの訓練手法の比較

こんにちは,Nextremer 高知オフィスのインターン生 藤田と申します.

Recurrent Neural Network (RNN) の学習時に,モデルに対するデータの与え方によっていくつかの手法があります.
いくつかを実装して翻訳タスクの実験をしてみましたので,その結果を皆さんに共有させていただきます.

はじめに

多くの方は,RNNの訓練時はタイムステップ毎にモデルに対して教師データを与えていると思います.この方法は,Teacher forcing と呼ばれています.
Teacher Forcing は,一般的に広く用いられている印象ですが, Exposure bias と呼ばれる問題が指摘されています[1].

Exposure bias とは,デコーダに与えるデータが訓練時とテスト時で異なるために,テスト中に一度正解から外れると誤差が累積してしまう問題です.テスト時に用いる入力は,デコーダ自身が予測したデータとなるため,このような問題が生じます.
一方で,モデルが出力したデータを利用する手法はFree running と呼ばれますが,Teacher forcing に比べて学習が不安定で収束が遅い傾向にあります.そこで,モデルの性能はある程度高く,exposure bias を抑制する手法として,Scheduled Sampling [2] と Semi-teacher-forced training [3](本記事では,便宜上 Semi-teacher forcing と表記することにします) と呼ばれるものが提案されています.

このように RNN にはいくつかの訓練手法がありますが,結局のところどれを使ったらいいのかは一般的な解釈が存在せず,とりあえず Teacher forcing を使っているというのが現状ではないでしょうか?

この記事では,Teacher forcing / Free running / Scheculed sampling / Semi-teacher forcing を翻訳精度と学習に要する時間の観点で比較した結果を紹介します.
結論を先に言うと,KFTT(京都フリー翻訳タスク)においては,Teacher forcing が良いということになりました.

実験で使用する訓練手法

ここでは,各訓練手法を sequence-to-sequence モデル[6]をベースに説明します.
入力系列を \(X = (x_1, …, x_S)\) , 出力系列 \(\hat{Y} = (\hat{Y}_0, \hat{Y}_1, …, \hat{Y}_{T\prime})\), 正解系列を \(Y = (y_0, y_1, …, y_T)\), X のエンコードによって得られる隠れ変数を \(h_S\) と定義します.ただし,\(y_0\), \(y_1\) は,特別な記号 “<EOS>” を表すものとします.

このモデルの目的は,条件付き確率 \(p(y_1, …, y_T | x_1, …, x_S)\) を推定することであり,各訓練手法では,その確率の計算に登場する \(d_t\) の値(図1~4を参照)が異なります.

Teacher frocing

各タイムステップの入力に教師データをそのまま使う手法です.
常に教師データを用いるため,\(d_t = y_{t-1}\) となります.
この手法は,学習が安定し収束は早いのですが,推定後のモデルでは Exposure bias が生じます.

図1. Teacher forcing (TF)
図1. Teacher forcing (TF)

Free running

直前のタイムステップにモデルからサンプルしたデータを入力として使う手法です.
モデル自身が出力したデータのみを用いるため, \(d_t = \hat{y}_{t-1}\) となります.
この手法では,学習時とテスト時ともにモデルからサンプルしたデータを使用するため Exposure bias は生じません.しかし,連鎖的な誤差の増大により学習が不安定となり収束が遅くなります.

図2. Free running (FR)
図2. Free running (FR)

Scheduled sampling [2]

Teacher forcing と Free running を組み合わせた手法です.デコーダ側の各タイムステップにおいて,教師データまたはモデルからサンプルしたデータを選択して使用します.よって,\(d_t\) は \(y_{t-1}\) または \(\hat{y}_ {t-1}\) となります.このとき,教師データを選択する確率は\( \epsilon \)で表され,訓練中はこの値を 「1 → 0」 と徐々に減衰させます.これは,学習初期は教師データの利用頻度を高くし学習の高速化を図ることが狙いです.文献 [2] では,この減衰手法として図4に示す3通りを紹介しています.

図3. Scheduled sampling (SS)
図3. Scheduled sampling (SS)
図4. 確率\( \epsilon \)の減衰手法
図4. 確率\( \epsilon \)の減衰手法

 

 

Semi-teacher forcing [3]

この手法は,もともとは感情音声合成のタスクで提案されたもので,教師データとサンプルデータを相加平均してデコーダに与えるという手法です.自然言語処理の分野でも適用できそうなので,今回の実験に採用しました.教師データとサンプルデータを両方使うという点では,Scheduled sampling と似たような手法ですが,デコード時に教師データとサンプルデータを混ぜ合わせて使用する点が異なります.
今回の実験でこの手法を用いるために 2 つ工夫をしました.
1つ目が,教師データとサンプルデータの埋め込みベクトルを平均するというものです.というのも,モデルからサンプルできるのは one-hot ベクトルなので,教師データとの平均を単純に計算できないからです.
2つ目が,加重平均をとり,教師データに対する重みをスケジューリングさせるというものです.これは,学習の安定化と学習後期に Free running に近い手法にさせるためです.
したがって,この手法における \(d_t\) は教師データ側の重み \(a\) を用いて図5のように表されます.

図5. Semi-teacher forcing (STF)
図5. Semi-teacher forcing (STF)

実験の設定

Chainer を用いて京都フリー翻訳タスク(KFTT)[4]とよばれる日英機械翻訳タスクを行いました.
モデルの実装は,GitHub に公開されている sequence-to-sequence モデルのサンプルコード(https://github.com/chainer/chainer/tree/master/examples/seq2seq) を再利用させていただきました.これに,Teacher forcing を除く3つの訓練手法とattentionを独自に追加したものを使いました.
また,作成したモデルの評価には, BLEU を用いました.
データセットとパラメータを以下に記します.

データセット

今回使用するデータセットは,KFTTのデータセットは http://www.phontron.com/kftt/index-ja.html からダウンロードしたもので,(Kyoto Free Translation Task (Data Only v. 1.0)) を使用しました.
KFTTのデータセットは http://www.phontron.com/kftt/index-ja.html からダウンロードしたもの (Kyoto Free Translation Task (Data Only v. 1.0)) を使用しました.
前処理の内容は下記の通りです.

【デフォルトで行われている前処理】

  • Mosesのスクリプトによる英語のトークン化
  •  KyTeaを用いた日本語の単語分割
  • データを4分割 (train / tune / dev / test)

【独自に行なった前処理】

  • 英字の小文字
  • train / dev に関しては,source と target それぞれに含まれる長さ51以上の文を除外

このデータセットの中から,3種類のデータ(train / dev / test)を実験で使用します.
各データの文の数はそれぞれ以下のとおりです.

  • train: 376,961
  • dev: 1,166
  • test: 1,160

また,語彙辞書は,サンプルコードと同様のGitHubリポジトリに公開されているスクリプトを利用して,頻出単語 40,000個で作成しました.

パラメータ

モデルのパラメータは下記の通りです.

  • LSTMの層数: 2
  • ユニット数: 512
  • Dropout rate: 0.1
  • Optimizer: Adam
  • ミニバッチサイズ: 64

Scheduled sampling と Semi-teacher forcing のスケジューリングは,次の設定で行いました.

  • アルゴリズム:逆シグモイド減衰 (Inverse sigmoid decay)
  • レート:1/50k または 1/100k

レートが 1/50k の場合,学習イテレーションが 50k に達するときに確率 \( \epsilon \) が 0 になることを表します.
Semi-teacher forcing に関しては,[3] と同様に \(a\) を 0.5 に固定した実験も行いました.

実験結果

BLEU

dev と test データそれぞれから算出したBLEUを表1に示します.

表1. 各実験項目ごとの BLEU

手法 条件 dev BLEU test BLEU
TF 15.5 17.95
FR 4.70 6.10
SS-50K 1/50k で\( \epsilon \)を減衰 13.30 15.13
SS-100K 1/100k で\( \epsilon \)を減衰 14.55 16.80
STF-FIXED \(a\) = 0.5 で固定 15.73 17.61
STF-50K 1/50k で \(a\) を減衰 13.93 15.40
STF-100K 1/100k で \(a\) を減衰 14.78 17.23

dev BLEU は,STF-FIXED が最も高く,test BLEU は TF が最も高くなりました.TF と STF-FIXED を用いるとほぼ同水準の翻訳精度となるモデルができるということでしょうか?
Scheduled sampling よりも Semi-teacher forcing の方が BLEU は高くなる傾向が見受けられます.
また,FR に関しては他の手法に比べて,BLEU の低さが目立ちます.
図6は,いくつかの手法における学習中のBLEUの変化をに示しています.
SS-100K では 39K iter.,STF-100K では 70K iter. で BLEU が急激に減少してしまいました.

図6_BLEUの変化
図6_BLEUの変化

test データ 1160 文をの3つのグループに分割し,各グループごとに算出した BLEU を表2 に示します.

  • グループ1: 文の長さが 1 ~ 25 の 778 文で構成
  • グループ2: 文の長さが 26 ~ 50 の 282 文で構成
  • グループ3: 文の長さが 51 以上の 99 文で構成

Exposure bias を考慮すると,文長が長くなるほど,TF の BLEU の低下が予想されましたが,グループ2と3に対しては最も高い BLEU を記録しました.

表2. 文の長さごとに算出した BLEU

手法 グループ1 グループ2 グループ3
TF 21.64 16.45 12.14
FR 10.13 3.81 1.60
SS 20.84 14.59 11.84
STF-FIXED 22.24 15.45 11.72
STF-FIXED 21.36 16.07 10.74

スケジューリングの効果

図7から,Free running に近づくほど,BLEU が著しく低下することがわかります.また図8から,レートは 1/50k よりも 1/100k の方が,BLEUが高くなりました.

今回は Semi-teacher forcing に対して教師データの重みの減衰をしてみましたが,固定した方がBLEUは多少なりとも高くなるかもしれませんね.

図7. スケジューリングがBLEUにもたらす効果(左)SS-100K(右)STF-100K
図7. スケジューリングがBLEUにもたらす効果(左)SS-100K(右)STF-100K
図8. レートの大きさを変えた場合のBLEUの変化(左)SS-100K と SS-50K(右)STF-100K と STF-50K
図8. レートの大きさを変えた場合のBLEUの変化(左)SS-100K と SS-50K(右)STF-100K と STF-50K

生成文

まず,比較的短い文に対する翻訳結果を見てみましょう.

【ソース1】 道元 ( どうげん ) は 、 鎌倉 時代 初期 の 禅僧 。
【ターゲット1】 dogen was a zen monk in the early kamakura period .

TF: dogen was a zen priest in the early kamakura period .
FR:he was the soto soto sect soto sect
SS-100K: he was the founder of the soto sect .
STF-FIXED: he was the founder of the soto sect .
STF-100K:  he was the founder of the soto sect .
【ソース2】 晩年 に 希 玄 と い う 異称 も 用い た 。
【ターゲット2】later in his life he also went by the name kigen .

TF: in his later years , he used different names as  .
FR: in his later years , name name name was was . . . .
SS-100K: in his later years , he used the name  .
STF-FIXED: in his later years , he was also used as a name of  .
STF-100K: in his later years , he was also called  .
【ソース3】 一般 に は 道元 禅師 と 呼 ば れ る 。
【ターゲット3】he is generally called dogen zenji .

TF: he is generally called dogen .
FR: he was generally called as zenji zenji zenji
SS-100K: he is generally called dogen zenji .
STF-FIXED: it is generally called dogen zenji .
STF-100K: it is generally called dogen .

ソース2の翻訳結果にが現れたのは,学習データに ”希 玄” / ”kigen”を含む文が無く,その対応関係を学習できなかったためだと思われます.FR以外の翻訳精度は,概ね妥当であり,同等な水準であると言えるのではないでしょうか?

次に,57単語で構成される文(学習には使っていない51単語以上の文)の翻訳結果を見てみます.

【ソース4】 関西 本線 の 支線 と し て の 沿革 を 持 つ ため 、 正式 な 起点 は 木津 駅 だ が 、 列車 運行 上 は 京都 から 木津 へ 向か う 列車 が 下り ( 列車 番号 は 奇数 ) 、 逆 が 上り ( 同 偶数 ) と な っ て い る 。
【ターゲット4】although the nara line officially starts at kizu station because it is historically a branch line of the kansai main line , outbound trains ( odd-numbered trains ) run from kyoto to kizu and inbound trains ( even-numbered trains ) run in the opposite direction .
TF: because the station starts with the main branch of the kansai main line , the station starts to kizu , but the train heading toward kizu ( number number of trains ) is outbound , and the reverse number is the opposite track ( number of trains ) .

FR: because the is the the the the the line main main main main line line , , , , , , the the the the the the the trains trains trains trains trains , trains , , trains ) ) , , , , ) ) ) ) ) . . .

SS-100K: since the station has a history of the kansai main line , the station is kizu station , but the trains heading toward kizu from kyoto are operated by the train ( odd number number number number number number ) , and the trains heading toward kizu ( the number number number number number number ) are the same as that of the main line ( the number number of trains ) .

STF-FIXED: because the station is a branch line of kansai main line , the station is kizu station ( the number of trains that is operated by the train ) , but the train ( numbers ) are operated from kyoto to kizu ( numbers ) , and the reverse side is the same as that of the train .

STF-100K: although the formal origin of the kansai main line is kizu station , the formal trains run from kyoto station to kizu station ( odd-numbered trains ) , and the outbound train is called the inbound train ( odd number number number ) , and the opposite of the station is called " inbound train . "

比較的長い文に対しても,TF の翻訳精度は比較的良く感じます.(ただし,実用的なレベルではありません.)
exposure bias の影響は思ったより受けていないと見受けられます.

学習処理に要する時間

表3から,Teacher forcing は他の手法に比べて約 9 倍処理時間が短くなりました.
Free running / Scheduled sampling / Semi-teacher forcing のデコーダの処理は,各タイムステップごとに 「→ 埋め込み層 → LSTM層 → Attention層 → 全結合層」のように逐次的にデータを処理する必要があります.
一方,Teacher forcing では,そのような逐次処理が必要なのは「LSTM層」のみで,その他の「→ 埋め込み層」と「→Attention 層 → 全結合層」では,各タイムステップの処理をまとめて行えます.今回の実装ではそのようにしているため,表3のような時間の差が生まれたのだと思います.

表3. 1000イテレーションの学習にかかる処理時間(秒)

Teacher forcing Free running Scheduled sampling Semi-teacher forcing
238 2149 2156 2240

まとめ

  • 日英機械翻訳タスク(KFTT)にて4つの訓練手法をいくつかのバリエーションを加えて行った結果, test BLEUベースで,
    「TF > STF-FIXED > STF-100K > SS-100K >> FR」
  • 翻訳結果の主観評価では,概ね Free running 以外は同等の翻訳精度となった(ただし,実用的な精度とは言えない)
  • 学習処理に要する時間は今回利用したコードでは,Teacher forcing が他の3手法に比べて 約 9 倍短くなった(Teacher forcing の一部の処理は各タイムステップで必要な処理をまとめることが可能なため)

以上の結果が,RNNを実装する際の皆さんの判断材料になれば幸いです.

今回使用した KFTT のコーパスは,括弧書きが多いなど独特なものであったため,データセットを変えたり,または,対話生成や要約といったの他のタスクを行ってみると結果は変わるかもしれません.特に,Semi-teacher forcing は タスクによっては,Teacher forcing より高くなる可能性があるかもしれません.(時間の都合でそこまで実験はできませんでした)
また,今回は調査外としましたが、Teacher forcingとFree runningにおける最終隠れ状態の差を小さくなるようにGAN使って学習させる手法として Professor forcing があります.こちらを用いるとおもしろい結果が得られるかもしれません.興味のある方は試してみてください.

今回の実験では,Semi-teacher forcing / Scheduled sampling では,Free running に近づくほど,BLEUが著しく低下していまいましたが,文献[2]でも同じ現象が観測されたみたいです.著者らはモデルがサンプリング確率\( \epsilon \)の勾配を考慮していないことがその原因だと考察していました.サンプリング確率\( \epsilon \)を学習によって獲得するようにすることでBLEUが上がるかもしれないということですが,時間の都合上その検証はしていません.

参考文献

[1] M. A. Ranzato, S. Chopra, M. Auli, W. Zaremba, Sequence level training with
recurrent neural networks, ICLR, 2016.
[2] S. Bengio, O. Vinyals, N. Jaitly, N. Shazeer, Scheduled sampling for sequence prediction with recurrent neural networks, Advances in Neural Information Processing Systems, 2015.
[3] Y. Lee, A. Rabiee, S. Y. Lee, Emotional End-to-End Neural Speech synthesizer, Advances in Neural Information Processing Systems, 2017.
[4] The Kyoto Free Translation Task (KFTT), http://www.phontron.com/kftt/index-ja.html, 2018
[5] M. T. Luong, H. Pham, C. D. Manning, Effective approaches to attention-based neural machine translation, arXiv preprint arXiv:1508.04025, 2015.
[6] I. Sutskever, O. Vinyals, Q. V. Le, Sequence to Sequence Learning with Neural Networks, Advances in neural information processing systems, 2014.
[7] Encoder-decoderモデルとTeacher Forcing,Scheduled Sampling,Professor Forcing (最終閲覧日 2018/07/17), http://satopirka.com/2018/02/encoder-decoder%E3%83%A2%E3%83%87%E3%83%AB%E3%81%A8teacher-forcingscheduled-samplingprofessor-forcing/

【補足】実験で使用したモデル

本実験では,Attention付きSequence-to-sequence モデル(図9)を使用しました.
これは,機械翻訳や対話生成などのタスクなどで広く利用されていて,自然言語処理の分野ではこのモデルをベースにして様々な手法が提案されたりしています.

以下では,Sequence-to-sequence モデル と Attention モデルを簡単に紹介します.

 

図9_本実験で使用するAttention付きSequence-to-sequenceモデル
図9_本実験で使用するAttention付きSequence-to-sequenceモデル

Sequence-to-sequence モデル

系列変換を行う Encoder-Decoderモデル(図10)の一つです.
入力系列をEncoder と呼ばれる LSTM を用いて固定長ベクトルに変換し,Decoder と呼ばれる LSTM を用いて系列を生成します.

図10_Sequence-to-sequence モデルの概略図
図10_Sequence-to-sequence モデルの概略図

Attention モデル

Attentionモデルは,エンコーダ側の全てのタイムステップにおける 隠れ状態を加味して出力を予測するモデルです.Attentionの処理では,Attentionベクトル(図11)を計算することが目的となります.

図11_Attentionの処理
図11_Attentionの処理