ギークなエンジニアを目指す男

機械学習系の知識を蓄えようとするブログ

単語と図で理解する自然言語処理(word2vec, RNN, LSTM)後編

前回に引き続き、後編です。

www.takapy.work

前編の簡単な復習

前編では、コンピュータに単語の意味を理解させる表現方法として下記の3種類を説明し、その中でもword2vecが有効な手段ということがわかりました。

  • シソーラスによる手法
  • カウントベースの手法
  • 推論ベースの手法(word2vec)

word2vecで使用されるニューラルネットワークのモデルには、CBOWskip-gramという2つのモデルがありました。

これらは、特定のコーパスを用いてニューラルネットワークモデルの学習をすることにより、その過程で計算される重みが結果として単語の分散表現となるという事を学びました。

本日はword2vecにも使われているフィードフォワード型のニューラルネットワークの問題点と、その解消方法としてRNNLSTMについて学んでいきます。

言語モデル

まず始めに言語モデルという概念を紹介しておく。

言語モデルとは単語の並びに対して確率を与えることである。単語の並びに対してそれがどれだけ起こり得るか、それがどれだけ自然な単語の並びであるかを確率で評価することである。

例えば、you say goodbyeという単語の並びには高い確率を、you say good dieという単語の並びには低い確率を出力する、というようなことを言語モデルは行う。(good dieなんて使う機会ないですよね、多分。)

言語モデルにおけるフィードフォワード型ニューラルネットワーク(word2vec)の問題点

先のCBOWモデルによって、例えばコンテキストを左側の2つの単語に限定することで、近似的に言語モデルを表現することは可能だが、このコンテキストのサイズは固定する必要があり、そのコンテキストよりもさらに左側にある単語の情報は無視されてしまう。

f:id:taxa_program:20190107232635p:plain
コンテキストとして左側のウィンドウだけを対象とする

コンテキストのサイズをとても大きな値(例:100など)にすることも可能だが、CBOWモデル(フィードフォワードのニューラルネットワーク)ではコンテキスト内の並びが無視されてしまうという問題がある。

蛇足だが、CBOWとはcontinuous bag-of-wordsの略で、bag-of-wordとは「袋の中にある単語」を意味し、これは袋の中の単語の「並び」が無視されることを表現している。
また、word2vecは、単語の分散表現を獲得することを目的に考案された手法であり、言語モデルとして利用されることはほとんどない。RNNによる言語モデルでも単語の分散表現は獲得できるが、語彙数増加への対応や単語の分散表現の質の向上のために、word2vecが提案された。という背景がある。

このような問題を解決するために登場するのが、リカンレントニューラルネットワーク、略してRNNである。

RNN

Recurrent Neural Networkを直訳すると循環するニューラルネットワークとなる。

循環するためには「閉じた経路」もしくは「ループする経路」が必要である。 データがループすることによって、情報は絶えず更新されることになる。

RNNレイヤ(RNNで用いられるレイヤ)を図で表すと、下記のようなイメージになる。

f:id:taxa_program:20181225003431p:plain
RNNレイヤイメージ

 x_tは時系列データを想定

このループ構造を展開することで、右方向に伸びるニューラルネットワークへと変形させることができる。

f:id:taxa_program:20181225005739p:plain
RNNのループ展開後イメージ

各時刻のRNNレイヤはそのレイヤへの入力一つ前のRNNレイヤからの出力を受け取っている。この2つの情報をもとに、その時刻の出力が計算される。(ここでいう時刻tとは自然言語処理の場合におけるt番目の単語という意味と同意)
この時の計算式は下記の数式で表すことができる。

 h_t=tanh( h_{t-1}W_h+x_tW_x+b )

数式を見ると、現在の出力  h_t は、ひとつ前の出力  h_{t-1} によって計算されることがわかる。
これは、RNNは h という状態を持っており、上記の式で更新されると解釈できる。

また、RNNレイヤを計算グラフで表現すると下記のようになる。

f:id:taxa_program:20190108231044p:plain
RNNレイヤの計算グラフ

RNNのh は「状態」を記憶し、時間が1ステップ(1単位)進むに従い上記数式の形で更新される。多くの場合、RNNのh_tは、隠れ状態隠れ状態ベクトルと呼ばれる。

BPTT(Backpropagation Through Time)

ループを展開した後のRNNについても、通常のニューラルネットワークと同様に誤差逆伝播法を用いることができる。 ここでの誤差逆伝播法は時間方向に展開したニューラルネットワークの誤差逆伝播法ということで、BPTT(Backpropagation Through Time)と呼ばれる。

BPTTの問題点

問題点として、長い時系列データを学習する際にBPTTで消費するコンピュータリソースが膨大になってしまう点が挙げられる。(各時刻のRNNレイヤの中間データをメモリに保持しておく必要があるため)
また、時間サイズが長くなると、逆伝播時の勾配が不安定になることも問題になる。

そこで次に述べるように、ある一定の長さでネットワークの繋がりを断ち切る必要がでてきます。

Truncated BPTT

ネットワークの逆伝播の繋がりだけを断ち切る手法である。 順伝播の流れは途切れることなく伝播させ、逆伝播のつながりは適当な長さで切り取り、その切り取られたネットワーク単位で学習を行う。

順伝播は切断させないため、RNNの学習を行う際はデータを順番に(シーケンシャルに)与える必要がある。シーケンシャルに与えないと、単語の並びを意識した学習ができなくなってしまう。 (ミニバッチ学習のように、データをランダムに選ぶことは推奨されない)

Truncated BPTTのミニバッチ学習

ミニバッチ学習を行うためには本来ならばバッチを考慮して、シーケンシャルにデータを与える必要がある。 そのために、データを与える開始位置をズラす必要がある。

例えば1000個の長さの時系列データに対して、時間の長さを10個単位で切るTruncated BPTTで学習する場合、ミニバッチのバッチ数を2として学習するにはどうしたら良いでしょう。 その場合、RNNレイヤの入力データとして、1つ目のバッチには、先頭から順にデータを与えていき、2つ目のバッチには、500番目のデータを開始位置としてそこから順にデータを与えていく必要がある。

つまり、開始位置を500だけズラすということになる。

長い時系列データを処理するとき、RNNの隠れ状態を維持する必要がある。このような隠れ状態を維持する機能はstatefulという言葉で表現され、多くのディープラーニングのフレームワークの引数としてstatefulが存在している。

RNN言語モデルの全体図

下記がざっくりとしたRNN言語モデル(RNNLM)の全体イメージ。

f:id:taxa_program:20181225010306p:plain
RNNのネットワーク図

最初の層はEmbeddingレイヤとなっており、単語IDを単語の分散表現(単語ベクトル)へと変換する。

そしてその分散表現が、RNNレイヤへと入力されます。このRNNレイヤは隠れ状態を次の層へ(図でいうと上方向)出力すると同時に、次時刻のRNNレイヤへ(図でいうと右方向)出力する。

そして、RNNレイヤが上方向に出力した隠れ状態は、Affineレイヤを経て、Softmaxレイヤと伝わり、確率が出力される。

このニューラルネットワークに対して、順伝播だけを考えて具体的なデータを流してみると、下記のようになる。

ここでの文章はYou say goodbye and I say hello.とする。

f:id:taxa_program:20190108000309p:plain
「You say goodbye and I say hello.」を処理するRNN

ここで出力している「say」や「goodbye」に関しては、Softmaxレイヤが出力する確率分布と共に記載している。

「say」の出力については、「goodbye」と「hello」の両方の確率分布が高くなることが予想できる。(you say goodbyeでも、you say helloでも文脈はおかしくない)

注目する点としては、RNNレイヤは「you say」という文脈を記憶しているという部分である。
より正確な表現をすると、「you say」という過去の情報をコンパクトな隠れ状態ベクトルとしてRNNが保持しているということになる。

LSTM(ゲート付きRNN)とは

上記で述べたRNNはあまり性能が良くない。その原因は、時系列データの長期依存関係をうまく学習できない点にある。

RNNLMの問題点

少し復習してみる。言語モデルが行う事は、これまでに与えられた単語から次に出現する単語を予測することであった。そこで、下記タスクを考えてみる。

f:id:taxa_program:20190108002455p:plain
「?」に入る単語は?

これをRNNLMが正しく答えるには、現在の文脈として「Tomが部屋でテレビを見ている事」そして「その部屋にMaryが入ってきた事」を記憶しておく必要がある。(このような状態を隠れ状態にエンコードして保持しておく必要がある)

この時、BPTTは下記のようになります。

f:id:taxa_program:20190108004224p:plain
正解ラベルが「Tom」であることを学習するときの勾配の流れ

RNNレイヤが過去方向に「意味のある勾配」を伝達することによって、時間方向の依存関係を学習することが可能になる。しかし、この勾配が途中で弱まってしまい、ほとんど何も情報を持たなくなってしまったら、重みパラメータは更新されなくなってしまう。

上記のようなシンプルなRNNレイヤのままでは、時間を遡るにつれて勾配が小さくなる勾配消失もしくは大きくなる勾配爆発のどちらかの運命を辿ってしまう。

勾配爆発への対策

勾配クリッピングと呼ばれる手法を用いることで解決することができる。
→簡単に説明すると、勾配のL2ノルムが閾値を超えてしまった場合、勾配を修正するというもの。

勾配消失への対策

勾配消失を解決するには、RNNレイヤのアーキテクチャを根本から変える必要がある。ここで登場するのがLSTM(ゲート付きRNN)である。

LSTM(Long short-term memory)の概要

LSTMとRNNのインターフェースを比較すると、下記のようになる。

f:id:taxa_program:20190108232308p:plain
RNNレイヤとLSTMレイヤの比較

大きな違いとして、LSTMレイヤにはcという経路が存在することが挙げられる。このcは記憶セルと呼ばれ、LSTM専用の記憶部に相当する。

この記憶セルの特徴は、それが自分自身だけで(LSTMレイヤ内だけで)データの受け渡しをするという部分である。

LSTMがゲート付きRNNと呼ばれる所以は、文字通りいくつかの「ゲート(門)」を機能として持っているからである。 このゲートがあることにより、下層から流れてきたデータを、ゲートの開き具合によって次層にどの程度流すかを制御することが可能となる。

このゲートの開き具合をコントロールするために、それ専用の重みパラメータが用いられる。この重みパラメータは学習によって更新される。(ゲートの開き具合もデータから自動的に学習させる)ゲートの開き具合を求めるにはsigmoid関数を使用する。(sigmoid関数の出力は0.0〜1.0の間の実数)

ここから、LSTMに備わっている様々なゲートの種類を見ていく。

LSTMに備わっているゲートについて

ゲートの種類としては、下記4つが挙げられる。

  • outputゲート
  • forgetゲート
  • inputゲート

また、後述するが、記憶セルに新しい情報を覚えさせるための「新しい記憶セル」も実装する。

先に最終的な計算グラフを図示しておく。

ここでσはsigmoid関数を表しており、×に関しては要素毎の積(アダマール積)を表す。tanhに関しては、 tanh( h_{t-1}W_h+x_tW_x+b ) の式を表している。

f:id:taxa_program:20190108235624p:plain
LSTMの計算グラフ

outputゲート(o)

tanh(c_t)の各要素に対して「それらが次時刻の隠れ状態としてどれだけ重要か」ということを調整する。

forgetゲート(f)

記憶セルに対して「何を忘れるか」を明示的に支持する。

新しい記憶セル(g)

新しく覚えるべき情報を記憶セルに追加する。

inputゲート(i)

新しい記憶セル(g)の各要素が、新たに追加する情報としてどれだけ価値があるかを判断する。このinputゲートによって、何も考えずに新しい情報を追加するのではなく、追加する情報の取捨選択を行う。

tanhの出力は-1.0〜1.0の実数である。この-1.0〜1.0の数値には、何らかのエンコードされた「情報」に対する強弱が表されていると解釈できる。一方、sigmoid関数の出力は0.0〜1.0である。これはデータをどれだけ通すかの割合を表す。そのため多くの場合、ゲートではsigmoid関数、実質的な「情報」を持つデータにはtanh関数が活性化関数として利用される。

なぜLSTMだと勾配消失が起きにくいのか

記憶セルcの逆伝播に着目すると、「+」ノードと「×」ノードだけを通ることが分かる。

f:id:taxa_program:20190109003210p:plain
記憶セルの逆伝播

「+」ノードの逆伝播は上流から伝わる勾配をそのまま流すだけなので、勾配の変化(劣化)は起きない。残る「×」ノードに関して、これは「行列の積」ではなく「要素ごとの積(アダマール積)」であり、毎時刻、異なるゲート値によって要素毎の積の計算が実施される。ここに勾配消失を起こさない理由がある。

「×」ノードの計算はforgetゲートによってコントロールされている。ここで、forgetノードが「忘れるべき」と判断した記憶セルの要素に対しては、その勾配の要素は小さくなる。一方で、forgetゲートが「忘れてはいけない」と判断した要素に対しては、その勾配の要素は劣化することなく過去方向へ伝わる。そのため、記憶セルの勾配は、(長期にわたって覚えておくべき情報に対しては)勾配消失を起こさずに伝播することが期待できる。

(考えついた研究者は天才か)

LSTMの改善案

詳しくは触れないが、下記のような手法でモデルの精度が向上する可能性がある。

  • LSTMレイヤの多層化

  • Dropout

  • 重み共有

最後に

前編、後編とまとめてみたが、図示することで数式だけよりも頭に入りやすかった。しかし、理論を勉強しただけではこれをどのように実際のビジネスやシステムに取り入れていけば良いかイメージし辛いと感じているのも事実である。 (例えば、単語の分散表現を得るためにはCBOWとLSTMどちらを利用するべきか、など。個人的な理解だと、計算リソース、計算時間ではCBOWが優勢、分散表現の精度ではLSTMが優勢?)

この辺りは、今行っている口コミのデータ解析や、kaggleのコンペ等を通してスキルアップさせていきたい。