強化学習で迷路の最短経路を見つける(Julia)

概要

簡単な迷路の経路探索を強化学習を用いて行うプログラムを書きました。
コードと結果は末尾に記載しています。

背景

強化学習は存在を知っているレベルだったんですが、簡単に実装がてら、原理くらいは分かるように勉強しました。
つまり詳しいことは知らんです。
あと、Juliaの勉強です。Pythonで実装すると、どうしても他のサイトのコードをそのままコピペしてしまうので。

環境

  • Julia v1.0
  • Jupyter

強化学習で迷路を探索するとは

僕の所感ですが、何回も(千回レベル)同じ迷路を探索して良い道を学習するというのが近いです。
目的があり(今回は「ゴールする」+「点数が下がるパネルを踏まない」)、その目的に沿うように最適化していきます。
そのためには、自分が「どの位置」で「どの方向」に行けば良いかという指標が必要です。
その指標を学習によって獲得し、その指標通りに迷路を進むと最短経路で最高得点が得られるということです。

強化学習

今回はQ学習とε-greedy法を用いました。

Q学習 - Wikipedia

ε-greedy法とは簡単に述べると、たまにランダムに迷路を進むというものです。

余談

前述で述べた指標ですが、実装的に述べると(迷路の位置数) * (方向数)のデータ空間と考えると分かりやすいです。
今回の迷路は狭いので、このデータ空間は問題になりませんが、大きい問題になると明らかにやばいです。(マリオとか)
そこで、この指標をニューラルネットワークによって記述するのが「DQN (Deep Q-Network)」というものらしいです。

所感

迷路の探索問題(縦型・横型探索とか)は有名で、何回か実装しようとして面倒になってやめたことがあるのですが、これならできるかなと思いました。
今後はDeep Q-networkを実装するか、JuliaのGUIアプリを作成して迷路をアプリで作成できるようにしたいです。

付録

str_field = 
    """
    #,#,#,#,#,#,#
    #,S,0,0,-10,0,#
    #,0,-10,0,0,0,#
    #,0,-10,0,-10,0,#
    #,0,0,0,-10,0,#
    #,0,-10,0,0,100,#
    #,#,#,#,#,#,#
    """
start_point = (2, 2)
goal_point = (6, 6)
field = map(x -> split(x, ","), split(str_field))
alpha = 0.2
gamma  = 0.9
e_greedy_ratio = 0.2
epochs = 1000
Qvalue = Dict()
function display_field(state, action)
    tmp_field = []
    tmp_field = deepcopy(field)
    tmp_field[state[1]][state[2]] = "@"
    for i in 1:size(tmp_field, 1)
        println(tmp_field[i])
    end
    @show state, action
end
function get_actions(point)
    """ 現在位置から行くことが可能な座標を求める """
    x, y = point
    around_map = [(x, y-1), (x, y+1), (x-1, y), (x+1, y)]
    return [hoge for hoge in around_map if !(field[hoge[1]][hoge[2]] in ["#", "S"])]
end
function get_Qvalue(state, action)
    hoge = (state, action)
    try
        return Qvalue[hoge]
    catch
        return 0.0
    end
end
function choose_action_greedy(state)
    """ greedy法で行動を決める """
    best_actions = []
    max_q_value = -1
    for a in get_actions(state)
        q_value = get_Qvalue(state, a)
        if q_value > max_q_value
            best_actions = [a, ]
            max_q_value = q_value
        elseif q_value == max_q_value
            push!(best_actions, a)
        end
    end
    return rand(best_actions)
end
function choose_action(state)
    """ e-greedy法で行動を決める. """
    if e_greedy_ratio < rand()
        return rand(get_actions(state))
    else
        return choose_action_greedy(state)
    end
end
function set_Qvalue(state, action, q_value)
    hoge = (state, action)
    Qvalue[hoge] = q_value
end
function update_Qvalue(state, action)
    Q_s_a = get_Qvalue(state, action)
    mQ_s_a = maximum([get_Qvalue(action, n_action) for n_action in get_actions(action)])
    r_s_a = parse(Int64, field[action[1]][action[2]])
    # calculate
    q_value = Q_s_a + alpha * ( r_s_a +  gamma * mQ_s_a - Q_s_a)
    # update
    set_Qvalue(state, action, q_value)
    return r_s_a != 0.0
end
function train()
    """ 迷路の探索 """
    state = start_point
    while true
        action = choose_action(state)
        if update_Qvalue(state, action)
            break
        else
            state = action
        end
    end
end
function test()
    state = start_point
    count = 0
    reward = 0
    while true
        count += 1
        try
            reward += parse(Int64, field[state[1]][state[2]])
        catch
            reward += 0
        end
        action = choose_action_greedy(state)
        display_field(state, action)
        if state == goal_point
            @show count, reward
            break
        end
        state = action
    end
end
for i in 1:epochs
    train()
end

test()

結果

  • # -> 壁
  • @ -> 現在位置
  • 数字 -> パネルの得点(100がゴール)
  • S -> スタート地点
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "@", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((2, 2), (2, 3))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "@", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((2, 3), (2, 4))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "0", "@", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((2, 4), (3, 4))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "@", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((3, 4), (4, 4))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "@", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((4, 4), (5, 4))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "@", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((5, 4), (6, 4))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "@", "0", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((6, 4), (6, 5))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "@", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((6, 5), (6, 6))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "@", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((6, 6), (6, 5))
(count, reward) = (9, 100)

MethodError: no method matching *(::Array{Float64,1}, ::Array{Float64,1})

Juliaのエラーで少し詰まったので共有しておきます。
端的に言うと、行列の積を計算するときに次元数を誤っています。
エラー文が分かりづらかったので詰まってしまいました。

環境

  • Julia v1.0.0

エラー

hoge = ones(3)
fuga = ones(3)
hoge * fuga

MethodError: no method matching *(::Array{Int64,1}, ::Array{Int64,1})

原因

Juliaの1次元配列は縦ベクトルなので、各配列の次元数が
hoge: 3×1行列
fuga: 3×1行列
となっているので、単純に行列の積ができないだけです。

解消

転置しました。

hoge = ones(3)
fuga = ones(3)
hoge' * fuga

3.0

所感

1次元配列が縦ベクトルというのが、他の言語と違い詰まってしまいました。
エラー文が正しくないと思う

LSTMの識別モデルを変えたら正答率が下がった話

あらすじ

前回、セリフの字面からアムロかシャアの識別を行ったら正答率73%だった。 obq777.hatenablog.com

今回

前回はデータの数が圧倒的に足りないとの考察を行いましたが、
スタンダードなLSTMを使ったのも一因かとも思ったので、self attentionを足してみました。

参考URL

qiita.com こちらのサイトによると、self attentionを用いることで、機械学習では難しい識別の根拠が得られるとのことで面白そうだと思いました。

モデル

前処理が前回と同じなので、割愛。
参考URL様のコードをもとにself attentionを追加しましたが、参考と変わらないので割愛。

根拠

Jupyterは直接HTMLがいじれるので、そのようにコードを変えました。
LSTMが「この単語はアムロ的」「これはシャア的」のような感じでやってると思ってます。(直感)
色が濃いほど根拠の確度が高く、正しい識別のときは赤間違えているときは青です

def highlight(word, attn, tmp):
    if tmp:
        html_color = '#%02X%02X%02X' % (255, int(255*(1 - attn)), int(255*(1 - attn)))
    else:
        html_color = '#%02X%02X%02X' % (int(255*(1 - attn)), int(255*(1 - attn)), 255)
    return '<span style="background-color: {}">{}</span>'.format(html_color, word)

def mk_html(sentence, attns, length, tmp):
    html = ""
    sentence = sentence.numpy().tolist()
    sentence = list(chain.from_iterable(sentence))
    for i, word in enumerate(sentence[:length]):
        html += ' ' + highlight(
            idx2word[word],
            attns[i], tmp
        )
    return html + "<br><br>\n"

hoge = []
for i, batch in enumerate(test_loader):
    x, y, length = batch
    encoder_outputs = encoder(x, length)
    output, attn = classifier(encoder_outputs)
    pred = output.data.max(1, keepdim=True)[1]
    a = attn.data[0, :, 0]
    pred = output.data.max(1, keepdim=True)[1]
    tmp = (pred == y.data)
    hoge.append(mk_html(x, a, length, tmp))

from IPython.display import HTML, display
for i in hoge:
    display(HTML(i))

結果

正答率

(前回 73%) → 70%
40epochほど回しましたが70%と奮いませんでした。
かなりの確度でデータ数のせいだと思っていますが、ハイパーパラメータの方もいじるといいかもしれません。

根拠

f:id:obq777:20180814231300p:plain

識別の根拠例
相手の名前を言ってるところなんかを根拠としてるのは良いと思います...
セリフが短いと使えないですね。
LSTMよりも畳み込みニューラルネットワークの方が精度出る気もしてきます。

LSTMで逆シャアのセリフを識別する(Pytorch)

LSTMを使って、機動戦士ガンダム 逆襲のシャア(以後、逆シャア)のセリフが誰のものかを識別しました。
注) 正答率が73%とあまり良くないので、あまりいい記述ではないのかもしれません。ごめんなさい。

環境

データ収集

lovegundam.dtiblog.com

こちらのブログから逆シャアのセリフをコピーして、テキストファイルにペーストしました。(data.txt)
ついでにimportも済ませておきます。

import pandas as pd
import numpy as np
import re
from collections import Counter
import MeCab

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.autograd import Variable

with open('data.txt') as f:
    data = f.readlines()
print(data[0:10], len(data))

データの整形

「誰が喋ったかの情報」と「何を喋ったか」に別けたいためCSVファイルを作成します。
*) 区切り文字がタブですが、便宜上CSVファイルとします。

person = 0
for line, serif in enumerate(data):
    data[line] = serif.replace('『', '「').replace('』', '」').strip()
    if serif[0] == '「':
        data[person] += serif
        data[line] = ''
    else:
        person = line
data = [x for x in data if x]

# 人名の抽出
csv_data = []
for line, serif in enumerate(data):
    text = serif.replace('\n', '').split('「')
    speaker = re.sub(r'[a-zA-Z]$', '', text[0])
    com = ''.join(text[1:])
    com = com.replace('「', '').replace('」', '')
    csv_data.append([speaker, com])
    
data = pd.DataFrame(csv_data, columns=['speaker', 'serif'])
data.to_csv('data.csv', sep='\t', index=False)

情報のふるい分け

事前に回したのですが、50個以上セリフがあるキャラクター10人に絞っても、正答率11%くらいしか出なかったので、アムロとシャアだけに絞ります。

data = pd.read_table('data.csv')
speakers = ['シャア', 'アムロ']
data = data[data['speaker'].isin(speakers)]

data = pd.DataFrame(data, columns=['speaker', 'serif'])
data.to_csv('data2.csv', sep='\t', index=False)

辞書の作成

学習するとき用のセリフの単語辞書とキャラクターの辞書を作成します。
thresholdは単語の出現回数の閾値です。2以上を設定すると、出現回数がthreshold以下のものは単語辞書に登録されず、<unk>に置きかわります。

word2idx = {}
idx2word = []
tagger = MeCab.Tagger("-Owakati")
threshold = 1

def add_word(word):
    if word not in word2idx:
        idx2word.append(word)
        word2idx[word] = len(idx2word) - 1
        
def vocab(word):
    if word not in word2idx:
        return word2idx['<unk>']
    else:
        return word2idx[word]
    
def wakati(target):
    target = tagger.parse(target).replace('\n', '')
    return '<bos> ' + target + '<eos>'

def assign_id(text, length):
    words = text.split()
    ids = torch.zeros(length).long()
    for token, word in enumerate(words):
        ids[token] = vocab(word)
    return ids

data = pd.read_table('data2.csv')
counter = Counter()
for target in data['serif']:
    words = wakati(target)
    words = words.split()
    counter.update(words)
    
words = [word for word, cnt in counter.items() if cnt >= threshold]

add_word('<pad>')
add_word('<bos>')
add_word('<eos>')
add_word('<unk>')

for word in words:
    add_word(word)

for speaker in list(set(data['speaker'])):
    idx2speaker.append(speaker)
    speaker2idx[speaker] = len(idx2speaker) - 1

DataLoaderの作成

学習用にDataLoaderを作成します。今回は「バッチサイズ = 1」で作成します。 2以上のときは、後述のpack_padded_sequneceでバッチごとに文の長さを降順にしないといけないのですが、実装がたるかったので今回はしていません。
特にvalidデータを使うとか考えてなかったので、単純にデータを
「学習:テスト = 9:1」にしています。

def get_dataset(texts, labels, bsz):
    data = []
    max_len = 0
    for i, target in enumerate(texts):
        hoge = wakati(target)
        data.append(hoge)
        if max_len < hoge.count(' ') + 1:
            max_len = hoge.count(' ') + 1
    
    print(len(data), max_len)
    
    dataset = []
    for i, target in enumerate(data):
        ids = assign_id(target, max_len)
        fuga = (ids, speaker2idx[labels[i]], target.count(' ') + 1)
        dataset.append(fuga)
        
    return DataLoader(dataset, batch_size=bsz, shuffle=True, drop_last=True)

data_size = len(data)
train_size = int(data_size * 0.9)
bsz = 1
train_loader = get_dataset(data['serif'][:train_size], list(data['speaker'][:train_size]), 1)
test_loader = get_dataset(data['serif'][train_size+1:], list(data['speaker'][train_size+1:]), 1)

モデルの作成

Dropoutを入れるべきかと思いましたが、入れた方が精度が悪かったので抜きました。

class estLSTM(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, label_size):
        super(estLSTM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.core = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, label_size)
        
    def forward(self, serif, length):
        emb = self.embed(serif)
        emb = torch.transpose(emb, 0, 1)
        emb = nn.utils.rnn.pack_padded_sequence(emb, length)
        outputs, h = self.core(emb)
        outputs = nn.utils.rnn.pad_packed_sequence(outputs)[0]
        output = self.linear(outputs[-1])
        return output

学習

ハイパーパラメータは適当です。特に検証もしていません。

use_gpu = torch.cuda.is_available()
embed_size = 256
hidden_size = 256
num_layers = 1
est = estLSTM(embed_size, hidden_size, len(word2idx), num_layers, len(speaker2idx))

if use_gpu:
    est = est.cuda()

lr = 4e-4
cri = nn.CrossEntropyLoss()
params = list(est.parameters())
optimizer = torch.optim.Adam(params, lr = lr)

num_epochs = 10000
best_loss = 10000

for epoch in range(num_epochs):
    epoch_loss = 0
    for i, (text, label, length) in enumerate(train_loader):
        if use_gpu:
            text = text.cuda()
            label = label.cuda()
        text = Variable(text)
        label = Variable(label)
        output = est(text, length)
        
        loss = cri(output, label)
        est.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.data[0]
        
    epoch_loss /= total_step  
    print('Epoch [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
         .format(epoch+1, num_epochs, epoch_loss, np.exp(epoch_loss)))
        
    if best_loss > epoch_loss:
        best_loss = epoch_loss
        with open('speaker-est.pt', 'wb') as f:
            torch.save(est.state_dict(), f)

テスト

手元の環境だと50epochくらいでまぁまぁ収束していました。

PATH = 'speaker-est.pt'
est.load_state_dict(torch.load(PATH, map_location='cpu'))
# テスト
acc = 0
for i, (text, label, length) in enumerate(test_loader):
    text = Variable(text)
    output = est(text, length)
    _, predicted = torch.max(output.data, 1)
    
    if label == predicted:
        acc += 1
        
print('Accurency = {} / {} ({})'.format(acc, i+1, acc / i+1))

考察

Accurency = 22 / 30 (0.7333333)
と73%の正答率を得ました。ランダムだと50%なので、ギリギリ及第点なのではないのでしょうか。
本当は95%くらい欲しいですが
学習に使ったデータ数が2人合わせて271個(他テスト:30個)だったので大分少ないという印象です。
戦闘中なんかは目に見えて、セリフが短くなる(「うわっ」とか)。題材選びがまずかった気がします。

機械学習で画像キャプション (機械学習)

前回まで

前回は、前々回までで収集したツイートを整形していました。

今回

データセットを作成し、機械学習をさせたいと思います。
1. 画像をCNNにかける (image → Tensor)
2. 画像とツイートのデータセットを作成する
3. 文章生成のモデルを作成する
4. 学習をさせ、テスト結果を確認する

参考

https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/sample.pyを参考に作成いたしました。

注意

今回もPython + Pytorch + Jupyterの使用をイメージしています。
importするものが過不足あるかもしれません。

画像をCNNにかける

import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.autograd import Variable
import torchvision
import numpy as np
import pandas as pd

vgg16_feature = models.vgg16(pretrained=True)
for p in vgg16_feature.features.parameters():
    p.requires_grad = False
vgg16_feature.classifier = nn.Sequential(
        nn.Linear(512 * 7 * 7, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, 4096)
    )

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

def extract_feature(image):
    img = Image.open('(画像のあるディレクトリ)' + image)
    img_tensor = preprocess(img)
    img_tensor.unsqueeze_(0)
    return vgg16_feature(Variable(img_tensor))

import time
twiiter_data = pd.read_table('tweet2.csv')
tweet_len = tweet_data.shape[0]

image_feature = []
feature_start_time = time.time()
for i, image in enumerate(tweet_data['image']):
    ex_feature = extract_feature(image)
    try:
        image_feature.append(ex_feature)
        if i % 500 == 0:
            print('({:d}/{:d}) | elapsed time: {:5.2f} | {}'.format(i, tweet_len, (time.time()-feature_start_time), image))
    except:
        print('(', i, '/', tweet_len, ')', image)
        
# 画像特徴量をファイル保存
torch.save(image_feature, 'image_feature.pt')

学習済みのVGG16というモデルの識別層を除去し、画像のTensorを抽出しています。

画像とツイートのデータセットを作成する

ここからは、今までとは別ファイルに記述しています。

import torch
import torch.nn as nn
import torchvision.datasets as dset
from torchvision import models
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pack_padded_sequence
import torch.optim as optim
import torch.utils.data as data

import MeCab
import math
import pickle
import os, time
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from collections import Counter
% matplotlib inline

辞書の作成

image_feature = torch.load('image_feature.pt')
print(len(image_feature))

word2idx = {}    # word => id
idx2word = []    # [id] = word
tagger = MeCab.Tagger("-Owakati")
threshold = 3

def add_word(word):
    if word not in word2idx:
        idx2word.append(word)
        word2idx[word] = len(idx2word) - 1

def vocab(word):
    if word not in word2idx:
        return word2idx['<unk>']
    else:
        return word2idx[word]

def wakati(target):
    wakati = tagger.parse(target)
    wakati = wakati.replace('\n', '')
    caption = '<bos> ' + wakati + '<eos>'
    return caption

def assign_id(text, length):
    words = text.split()
    ids = torch.zeros(length).long()
    for token, word in enumerate(words):
        ids[token] = vocab(word)
    return ids

import pandas as pd
tweet_data = pd.read_table('tweet2.csv')
counter = Counter()
for target in tweet_data['text']:
    words = wakati(target)
    words = words.split()
    counter.update(words)
        
words = [word for word, cnt in counter.items() if cnt >= threshold]

add_word('<pad>')
add_word('<bos>')
add_word('<eos>')
add_word('<unk>')

for word in words:
    add_word(word)

辞書の作成(但し、出現回数>=2回のものに限る)
* → 文章の末尾に最大長から足りない分だけ付与する * → 文章の開始記号 * → 文章の終了記号 * → 辞書にない言葉(出現回数が少ない)の代わり

データセットの作成

def collate_fn(data):
    data.sort(key=lambda x: x[2], reverse=True)
    images, captions, length = zip(*data)
    images = torch.stack(images, 0)
    targets = torch.stack(captions, 0)
    length = list(length)
    return images, targets, length

def get_dataset(coco, img_feature, bsz):
    data = []
    max_length = 0
    start_time = time.time()
    for i, target in enumerate(coco):
            caption = wakati(target)
            img_feature_caption = (img_feature[i], caption)
            data.append(img_feature_caption)
            if max_length < caption.count(' ') + 1:
                    max_length = caption.count(' ') + 1

    print('data num = ', len(data), 'max len=', max_length)
    data_set = []
    for img, target in data:
        ids = assign_id(target, max_length)
        img_cap_len = (img.data, ids, target.count(' ') + 1)
        data_set.append(img_cap_len)

    return DataLoader(data_set, batch_size=bsz, shuffle=True, drop_last=True, collate_fn=collate_fn)

bsz = 64
tweet_len = len(image_feature)

train_size = int(tweet_len * 0.8)
valid_size = int(tweet_len * 0.9)
train_dataset = get_dataset(tweet_data['text'][:train_size], image_feature[:train_size], bsz)
val_dataset = get_dataset(tweet_data['text'][train_size+1:valid_size], image_feature[train_size+1:valid_size], bsz)
test_dataset = get_dataset(tweet_data['text'][valid_size+1:], image_feature[valid_size+1:], bsz)

今回はvalidとtestは使いませんが、一応作成しています。

文章生成のモデルを作成する

生成モデルの構築

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """Set the hyper-parameters and build the layers."""
        super(DecoderRNN, self).__init__()
        self.encoder = nn.Linear(4096, embed_size)
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seg_length = max_seq_length
        
    def forward(self, features, captions, lengths):
        """Decode image feature vectors and generates captions."""
        features = self.encoder(features)
        embeddings = self.embed(captions)
        embeddings = torch.cat((features, embeddings), 1)
        packed = pack_padded_sequence(embeddings, lengths, batch_first=True) 
        hiddens, _ = self.lstm(packed)
        outputs = self.linear(hiddens[0])
        return outputs
    
    def sample(self, features, states=None):
        """Generate captions for given image features using greedy search."""
        sampled_ids = []
        inputs = features.unsqueeze(1)
        for i in range(self.max_seg_length):
            hiddens, states = self.lstm(inputs, states)
            outputs = self.linear(hiddens.squeeze(1))
            _, predicted = outputs.max(1)
            sampled_ids.append(predicted)
            inputs = self.embed(predicted)
            inputs = inputs.unsqueeze(1)
        sampled_ids = torch.stack(sampled_ids, 1)
        return sampled_ids

各種パラメータの設定

use_gpu = torch.cuda.is_available()
embed_size = 512
hidden_size = 512
num_layers = 1
decoder = DecoderRNN(embed_size, hidden_size, len(word2idx), num_layers)
if use_gpu:
    decoder = decoder.cuda()
learning_rate = 4e-4
criterion = nn.CrossEntropyLoss()
params = list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=learning_rate)

学習

num_epochs = 10000
total_step = len(train_dataset)
log_step = 3000
for epoch in range(num_epochs):
    for i, (features, captions, lengths) in enumerate(train_dataset):
        # Set mini-batch dataset
        if use_gpu:
            features = features.cuda()
            captions = captions.cuda()
        features = Variable(features)
        captions = Variable(captions)
        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

        # Forward, backward and optimize
        outputs = decoder(features, captions, lengths)
        loss = criterion(outputs, targets)
        decoder.zero_grad()
        loss.backward()
        optimizer.step()

        # Print log info
        if i % log_step == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                  .format(epoch, num_epochs, i, total_step, loss.data[0], np.exp(loss.data[0])))

    # Save the model checkpoints
    with open('tw_cap.pt', 'wb') as f:
        torch.save(decoder.state_dict(), f)

epoch数10000とlog_stepは適当です。適せん変更してください。

学習をさせ、テスト結果を確認する

tmp = (適当な画像No.)
feature = image_feature[tmp]
sample_ids = decoder.sample(feature)
sampled_ids = sample_ids[0].cpu().numpy()
# Convert word_ids to words
sampled_caption = []
for word_id in sampled_ids:
    word = idx2word[word_id]
    sampled_caption.append(word)
    if word == '<eos>':
        break
sentence = ' '.join(sampled_caption)
    
print (sentence)

dir = '(画像のあるディレクトリ)'
img = Image.open(dir + tweet_list[tmp][1])
plt.imshow(img)

生成例

入力
f:id:obq777:20180621101006p:plain
出力
< bos >< bos >昼もパスタにしよう< eos >< eos >

入力
f:id:obq777:20180621112051p:plain
出力
< bos > < bos > 【 大喜 利 】 海鮮 丼 を 食べる 際 、 他 の 魚 は 豪快 に

失敗しとるやないか
< bos >などの記号が2回生成されてますね。適せん除去すればいい感じに
< bos >と< eos >に囲まれた文章はなかなかいい感じです。
パスタや海鮮丼など、画像の特徴はとれていると思います。

おしまい

ありがとうございました。

機械学習で画像キャプション (データ前処理)

前回のおさらい

前回は学習データとなるツイートを収集したので、今回はデータの前処理を行います。  

注意

前回はRubyを使いましたが、今回からはPythonとJupyterを使います。  

前処理を行なう必要性

  • ニュースなど重複したツイートが存在する
  • 全ての画像をJPEGに統一する
  • 白黒画像を除去する
    特に下2つは、学習の際にTensorにすると、次元数が異なるのでエラーとなります。

前処理

インポート

import pandas as pd
import MeCab
from PIL import Image
import os, sys

重複ツイートの除去

twitter_data = pd.read_table('tweet.csv')
print(twitter_data.shape)
sorted_twitter = twitter_data.sort_values(['text'])
no_duplicated_twitter = sorted_twitter.drop_duplicates(['text', 'image'],  keep='first')
print(no_duplicated_twitter.shape)
no_duplicated_twitter.to_csv('removeDupulicateTweet.csv', index=False, sep='\t')

画像フォーマットの統一

# CSVファイル書き換え
png = []
for line, tweet in enumerate(tweet_data['image']):
    if '.png' in tweet:
        png.append(tweet)
        tweet_data['image'][line] = tweet.replace('.png', '.jpg')
print(len(png))

dir = '(画像のあるディレクトリ)'
for filename in png:
    img = Image.open(dir + filename).convert('RGB')
    jpg_filename = filename.replace('.png', '.jpg')
    img.save(dir + jpg_filename, 'jpeg')

CSVファイルの画像ファイル名を先に変更して、その後画像ファイルのフォーマットを変換しています。

白黒画像の除去

import torch
import torch.nn as nn
from torchvision import transforms
from torch.autograd import Variable
import torchvision
import numpy as np

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

dir = '(画像のあるディレクトリ)'
def img_tensor_size(image):
    img = Image.open(dir + image)
    img_tensor = preprocess(img)
    img_tensor.unsqueeze_(0)
    return Variable(img_tensor).size()

error = []
expected_size = torch.zeros([1, 3, 224, 224]).size()
for line, tweet in enumerate(tweet_data['image']):
    if img_tensor_size(tweet) != expected_size:
        error.append(line)
tweet_data.drop(error, inplace=True)
print(tweet_data.shape)

tweet_data.to_csv('tweet2.csv', index = False, sep='\t')

白黒画像の除去と言いながら、万が一のことも考えて、欲しいTensorのサイズと異なるものだと除去するようにしてます。  

機械学習で行う画像キャプション (データ収集)

画像キャプションとは

画像に説明文を付与するものです。
それを今回は、機械学習で行いたいと思います。

実行例

f:id:obq777:20180621101006p:plain
昼もパスタにしよう

注意

初心者の記事ですので優しくご覧ください

使用するもの

流れ

  1. 学習用の画像と説明文のペアのデータ収集 ← 今回
  2. 収集したデータを前処理
  3. 機械学習

データ収集

画像キャプションのデータといえば、英語だと MSCOCO、日本語だと STAIR Captionsが有名どころです。
今回は軽くサクッとさせたいため、Twitterからデータを収集します。

機械学習Python + Pytorchで行ないますが、 Rubyでツイートを収集します

require 'twitter'
require 'fileutils'
require 'csv'
require 'open-uri'
require 'uri'

# twitterAPI認証
client = Twitter::REST::Client.new do |config|
    config.consumer_key =
    config.consumer_secret = 
    config.access_token = 
    config.access_token_secret = 
end

# 画像保存
def aImageSave(url)
    # tmp_path = "/Users/takeshi/Documents/twitterImages/#{File.basename(url)}"
    tmp_path = "twitterData/twitterImages/#{File.basename(url)}"
    print "#{url} "
    File.open(tmp_path, 'w+b') do |f|
        begin
            # 404でエラー
            f.write open(url).read
        rescue
            return false
        end
    end
    puts "saved"
    return true
end

# URLの除去
def removeUrl(text)
    urls = []
    URI::extract(text, %w{http https}).uniq.each do |uri|
        urls << uri
    end
    urls.each do |uri|
        text = text.delete(uri)
    end
    return text
end

# ハッシュタグの除去
def removeHashtag(text)
    hashtag = text.scan(/[##][A-Za-zA-Za-z一-鿆0-90-9ぁ-ヶヲ-゚ー]+/).map(&:strip)
    hashtag.each do |tag|
        text = text.delete(tag)
    end
    return text
end

# コンマの除去
def removeComma(text)
    return text.gsub!(/,/, '')
end


# 画像URLからベースネームを取得
def removeDomain(imageUrl)
    splitUrl = imageUrl.chomp.split('/')
    return splitUrl[4]
end

# tweet取得
def getTweet(client)
    # 検索ワード
    words = ["ラーメン", "カツ丼", "天丼", "カレー", "牛丼", "海鮮丼", "親子丼", "定食",
        "オムライス", "パスタ", "ピザ", "サラダ", "焼肉", "蕎麦", ""]
    limit = 80   #180/15m
    words.each do |word|
        max_id = 0
        limit.times do |count|
            tcount = 0
            puts "word = #{word}, count = #{count}"
            client.search(word, lang: "ja", result_type: "recent", exclude: "retweets", max_id: max_id)
                .take(100).each do |tweet|
                # 画像を含んでいる場合
                if !tweet.media.empty?
                    # リプライを含めない
                    if tweet.in_reply_to_user_id?
                        next
                    end
                    # 画像1枚だけのツイートに絞る
                    if tweet.media[0].type == 'photo'
                        if tweet.media[1].nil?
                            saveSuc = aImageSave(tweet.media[0].media_url.to_s)
                            if !saveSuc
                                next
                            end
                            text = removeUrl(tweet.text)
                            text = removeHashtag(text)
                            imageBasename = removeDomain(tweet.media[0].media_url.to_s)
                            CSV.open('twitterData/tweet.csv', 'a', :col_sep => "\t") do |csv|
                                csv << [text, imageBasename, word]
                            end
                            puts "#{text} saved"
                            puts tweet.created_at
                            tcount += 1
                        end
                    end
                end
                max_id = tweet.id
            end
            max_id -= 1
            if tcount == 0
                break
            end
            sleep(10)
        end
    end
end

unless FileTest.exist?("twitterData")
    FileUtils.mkdir_p("twitterData")
end

# CSVファイルが存在しない場合
unless FileTest.exist?('twitterData/tweet.csv')
    CSV.open('twitterData/tweet.csv', 'w', :col_sep => "\t") do |csv|
        csv << ["text", "image", "search word"]
    end
end

# 画像フォルダが存在しない場合
unless FileTest.exist?("twitterData/twitterImages")
    FileUtils.mkdir_p("twitterData/twitterImages")
end

getTweet(client)

解説

今回はTwitterから、食べ物系の画像とツイートの1対1のペアを取得しています。
その際に、ツイートからURLとハッシュタグを除去しています。

API制限ですが、15分間で制限を越えないようにsleepメソッドで休み休み動作します。

次回

次回は上記のプログラムで収集したデータの前処理を行ないます。