LSTMの識別モデルを変えたら正答率が下がった話

あらすじ

前回、セリフの字面からアムロかシャアの識別を行ったら正答率73%だった。 obq777.hatenablog.com

今回

前回はデータの数が圧倒的に足りないとの考察を行いましたが、
スタンダードなLSTMを使ったのも一因かとも思ったので、self attentionを足してみました。

参考URL

qiita.com こちらのサイトによると、self attentionを用いることで、機械学習では難しい識別の根拠が得られるとのことで面白そうだと思いました。

モデル

前処理が前回と同じなので、割愛。
参考URL様のコードをもとにself attentionを追加しましたが、参考と変わらないので割愛。

根拠

Jupyterは直接HTMLがいじれるので、そのようにコードを変えました。
LSTMが「この単語はアムロ的」「これはシャア的」のような感じでやってると思ってます。(直感)
色が濃いほど根拠の確度が高く、正しい識別のときは赤間違えているときは青です

def highlight(word, attn, tmp):
    if tmp:
        html_color = '#%02X%02X%02X' % (255, int(255*(1 - attn)), int(255*(1 - attn)))
    else:
        html_color = '#%02X%02X%02X' % (int(255*(1 - attn)), int(255*(1 - attn)), 255)
    return '<span style="background-color: {}">{}</span>'.format(html_color, word)

def mk_html(sentence, attns, length, tmp):
    html = ""
    sentence = sentence.numpy().tolist()
    sentence = list(chain.from_iterable(sentence))
    for i, word in enumerate(sentence[:length]):
        html += ' ' + highlight(
            idx2word[word],
            attns[i], tmp
        )
    return html + "<br><br>\n"

hoge = []
for i, batch in enumerate(test_loader):
    x, y, length = batch
    encoder_outputs = encoder(x, length)
    output, attn = classifier(encoder_outputs)
    pred = output.data.max(1, keepdim=True)[1]
    a = attn.data[0, :, 0]
    pred = output.data.max(1, keepdim=True)[1]
    tmp = (pred == y.data)
    hoge.append(mk_html(x, a, length, tmp))

from IPython.display import HTML, display
for i in hoge:
    display(HTML(i))

結果

正答率

(前回 73%) → 70%
40epochほど回しましたが70%と奮いませんでした。
かなりの確度でデータ数のせいだと思っていますが、ハイパーパラメータの方もいじるといいかもしれません。

根拠

f:id:obq777:20180814231300p:plain

識別の根拠例
相手の名前を言ってるところなんかを根拠としてるのは良いと思います...
セリフが短いと使えないですね。
LSTMよりも畳み込みニューラルネットワークの方が精度出る気もしてきます。