論文紹介: "R-TRANSFORMER: RECURRENT NEURAL NETWORK ENHANCED TRANSFORMER"
https://arxiv.org/pdf/1907.05572.pdf
概要
Globalな情報はTransformerのSelf-Attentionで、Localな情報をRNNで取得するという手法の提案。Transformerは大域的な情報に強いものの局所情報はposition embeddingという限られた情報に依存しているため、これをRNNで代替/補強するというアイデア。言語モデルで優秀な精度を記録。
手法
3つの異なるネットワークが階層的に配置されている。最下層はlocalRNNで、一定の幅を持つ局所ウィンドウをシークエンスの並びに沿って徐々に動かしていく。RNNはそれぞれの局所ウィンドウの中で実行される。中間層はmulti-head attention networkで、グローバルな長いシークエンスの依存関係を捉える。最上層はfeedforward networkで、非線形の特徴抽出を行う。これら3つのネットワークは残差とlayer normalizationによって結合される。
結果
各データ(画像データ, 音声データ, 自然言語)それぞれについて従来の手法(RNN, GRU, LSTM, TCN, Transformer)と比較した結果いずれのデータにおいても最高の精度を叩き出した。R-transformerはRNNとmulti-head attention poolingの良いとこ取りと相互の弱点補完をしている。
コメント
画像データには有名な手書き数字画像MNISTが用いられたが、28x28の画像を784x1のシークエンスに変換してシークエンスモデルに読み込ませている点は興味深い。MNISTはkaggleにもコンペが出ているのでR-transformerが従来型のモデルより高い精度を出せるのか試してみるのも面白そうだ。