LSTMで逆シャアのセリフを識別する(Pytorch)

LSTMを使って、機動戦士ガンダム 逆襲のシャア(以後、逆シャア)のセリフが誰のものかを識別しました。
注) 正答率が73%とあまり良くないので、あまりいい記述ではないのかもしれません。ごめんなさい。

環境

データ収集

lovegundam.dtiblog.com

こちらのブログから逆シャアのセリフをコピーして、テキストファイルにペーストしました。(data.txt)
ついでにimportも済ませておきます。

import pandas as pd
import numpy as np
import re
from collections import Counter
import MeCab

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.autograd import Variable

with open('data.txt') as f:
    data = f.readlines()
print(data[0:10], len(data))

データの整形

「誰が喋ったかの情報」と「何を喋ったか」に別けたいためCSVファイルを作成します。
*) 区切り文字がタブですが、便宜上CSVファイルとします。

person = 0
for line, serif in enumerate(data):
    data[line] = serif.replace('『', '「').replace('』', '」').strip()
    if serif[0] == '「':
        data[person] += serif
        data[line] = ''
    else:
        person = line
data = [x for x in data if x]

# 人名の抽出
csv_data = []
for line, serif in enumerate(data):
    text = serif.replace('\n', '').split('「')
    speaker = re.sub(r'[a-zA-Z]$', '', text[0])
    com = ''.join(text[1:])
    com = com.replace('「', '').replace('」', '')
    csv_data.append([speaker, com])
    
data = pd.DataFrame(csv_data, columns=['speaker', 'serif'])
data.to_csv('data.csv', sep='\t', index=False)

情報のふるい分け

事前に回したのですが、50個以上セリフがあるキャラクター10人に絞っても、正答率11%くらいしか出なかったので、アムロとシャアだけに絞ります。

data = pd.read_table('data.csv')
speakers = ['シャア', 'アムロ']
data = data[data['speaker'].isin(speakers)]

data = pd.DataFrame(data, columns=['speaker', 'serif'])
data.to_csv('data2.csv', sep='\t', index=False)

辞書の作成

学習するとき用のセリフの単語辞書とキャラクターの辞書を作成します。
thresholdは単語の出現回数の閾値です。2以上を設定すると、出現回数がthreshold以下のものは単語辞書に登録されず、<unk>に置きかわります。

word2idx = {}
idx2word = []
tagger = MeCab.Tagger("-Owakati")
threshold = 1

def add_word(word):
    if word not in word2idx:
        idx2word.append(word)
        word2idx[word] = len(idx2word) - 1
        
def vocab(word):
    if word not in word2idx:
        return word2idx['<unk>']
    else:
        return word2idx[word]
    
def wakati(target):
    target = tagger.parse(target).replace('\n', '')
    return '<bos> ' + target + '<eos>'

def assign_id(text, length):
    words = text.split()
    ids = torch.zeros(length).long()
    for token, word in enumerate(words):
        ids[token] = vocab(word)
    return ids

data = pd.read_table('data2.csv')
counter = Counter()
for target in data['serif']:
    words = wakati(target)
    words = words.split()
    counter.update(words)
    
words = [word for word, cnt in counter.items() if cnt >= threshold]

add_word('<pad>')
add_word('<bos>')
add_word('<eos>')
add_word('<unk>')

for word in words:
    add_word(word)

for speaker in list(set(data['speaker'])):
    idx2speaker.append(speaker)
    speaker2idx[speaker] = len(idx2speaker) - 1

DataLoaderの作成

学習用にDataLoaderを作成します。今回は「バッチサイズ = 1」で作成します。 2以上のときは、後述のpack_padded_sequneceでバッチごとに文の長さを降順にしないといけないのですが、実装がたるかったので今回はしていません。
特にvalidデータを使うとか考えてなかったので、単純にデータを
「学習:テスト = 9:1」にしています。

def get_dataset(texts, labels, bsz):
    data = []
    max_len = 0
    for i, target in enumerate(texts):
        hoge = wakati(target)
        data.append(hoge)
        if max_len < hoge.count(' ') + 1:
            max_len = hoge.count(' ') + 1
    
    print(len(data), max_len)
    
    dataset = []
    for i, target in enumerate(data):
        ids = assign_id(target, max_len)
        fuga = (ids, speaker2idx[labels[i]], target.count(' ') + 1)
        dataset.append(fuga)
        
    return DataLoader(dataset, batch_size=bsz, shuffle=True, drop_last=True)

data_size = len(data)
train_size = int(data_size * 0.9)
bsz = 1
train_loader = get_dataset(data['serif'][:train_size], list(data['speaker'][:train_size]), 1)
test_loader = get_dataset(data['serif'][train_size+1:], list(data['speaker'][train_size+1:]), 1)

モデルの作成

Dropoutを入れるべきかと思いましたが、入れた方が精度が悪かったので抜きました。

class estLSTM(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, label_size):
        super(estLSTM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.core = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, label_size)
        
    def forward(self, serif, length):
        emb = self.embed(serif)
        emb = torch.transpose(emb, 0, 1)
        emb = nn.utils.rnn.pack_padded_sequence(emb, length)
        outputs, h = self.core(emb)
        outputs = nn.utils.rnn.pad_packed_sequence(outputs)[0]
        output = self.linear(outputs[-1])
        return output

学習

ハイパーパラメータは適当です。特に検証もしていません。

use_gpu = torch.cuda.is_available()
embed_size = 256
hidden_size = 256
num_layers = 1
est = estLSTM(embed_size, hidden_size, len(word2idx), num_layers, len(speaker2idx))

if use_gpu:
    est = est.cuda()

lr = 4e-4
cri = nn.CrossEntropyLoss()
params = list(est.parameters())
optimizer = torch.optim.Adam(params, lr = lr)

num_epochs = 10000
best_loss = 10000

for epoch in range(num_epochs):
    epoch_loss = 0
    for i, (text, label, length) in enumerate(train_loader):
        if use_gpu:
            text = text.cuda()
            label = label.cuda()
        text = Variable(text)
        label = Variable(label)
        output = est(text, length)
        
        loss = cri(output, label)
        est.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.data[0]
        
    epoch_loss /= total_step  
    print('Epoch [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
         .format(epoch+1, num_epochs, epoch_loss, np.exp(epoch_loss)))
        
    if best_loss > epoch_loss:
        best_loss = epoch_loss
        with open('speaker-est.pt', 'wb') as f:
            torch.save(est.state_dict(), f)

テスト

手元の環境だと50epochくらいでまぁまぁ収束していました。

PATH = 'speaker-est.pt'
est.load_state_dict(torch.load(PATH, map_location='cpu'))
# テスト
acc = 0
for i, (text, label, length) in enumerate(test_loader):
    text = Variable(text)
    output = est(text, length)
    _, predicted = torch.max(output.data, 1)
    
    if label == predicted:
        acc += 1
        
print('Accurency = {} / {} ({})'.format(acc, i+1, acc / i+1))

考察

Accurency = 22 / 30 (0.7333333)
と73%の正答率を得ました。ランダムだと50%なので、ギリギリ及第点なのではないのでしょうか。
本当は95%くらい欲しいですが
学習に使ったデータ数が2人合わせて271個(他テスト:30個)だったので大分少ないという印象です。
戦闘中なんかは目に見えて、セリフが短くなる(「うわっ」とか)。題材選びがまずかった気がします。