機械学習で画像キャプション (機械学習)

前回まで

前回は、前々回までで収集したツイートを整形していました。

今回

データセットを作成し、機械学習をさせたいと思います。
1. 画像をCNNにかける (image → Tensor)
2. 画像とツイートのデータセットを作成する
3. 文章生成のモデルを作成する
4. 学習をさせ、テスト結果を確認する

参考

https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/sample.pyを参考に作成いたしました。

注意

今回もPython + Pytorch + Jupyterの使用をイメージしています。
importするものが過不足あるかもしれません。

画像をCNNにかける

import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.autograd import Variable
import torchvision
import numpy as np
import pandas as pd

vgg16_feature = models.vgg16(pretrained=True)
for p in vgg16_feature.features.parameters():
    p.requires_grad = False
vgg16_feature.classifier = nn.Sequential(
        nn.Linear(512 * 7 * 7, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, 4096)
    )

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

def extract_feature(image):
    img = Image.open('(画像のあるディレクトリ)' + image)
    img_tensor = preprocess(img)
    img_tensor.unsqueeze_(0)
    return vgg16_feature(Variable(img_tensor))

import time
twiiter_data = pd.read_table('tweet2.csv')
tweet_len = tweet_data.shape[0]

image_feature = []
feature_start_time = time.time()
for i, image in enumerate(tweet_data['image']):
    ex_feature = extract_feature(image)
    try:
        image_feature.append(ex_feature)
        if i % 500 == 0:
            print('({:d}/{:d}) | elapsed time: {:5.2f} | {}'.format(i, tweet_len, (time.time()-feature_start_time), image))
    except:
        print('(', i, '/', tweet_len, ')', image)
        
# 画像特徴量をファイル保存
torch.save(image_feature, 'image_feature.pt')

学習済みのVGG16というモデルの識別層を除去し、画像のTensorを抽出しています。

画像とツイートのデータセットを作成する

ここからは、今までとは別ファイルに記述しています。

import torch
import torch.nn as nn
import torchvision.datasets as dset
from torchvision import models
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pack_padded_sequence
import torch.optim as optim
import torch.utils.data as data

import MeCab
import math
import pickle
import os, time
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from collections import Counter
% matplotlib inline

辞書の作成

image_feature = torch.load('image_feature.pt')
print(len(image_feature))

word2idx = {}    # word => id
idx2word = []    # [id] = word
tagger = MeCab.Tagger("-Owakati")
threshold = 3

def add_word(word):
    if word not in word2idx:
        idx2word.append(word)
        word2idx[word] = len(idx2word) - 1

def vocab(word):
    if word not in word2idx:
        return word2idx['<unk>']
    else:
        return word2idx[word]

def wakati(target):
    wakati = tagger.parse(target)
    wakati = wakati.replace('\n', '')
    caption = '<bos> ' + wakati + '<eos>'
    return caption

def assign_id(text, length):
    words = text.split()
    ids = torch.zeros(length).long()
    for token, word in enumerate(words):
        ids[token] = vocab(word)
    return ids

import pandas as pd
tweet_data = pd.read_table('tweet2.csv')
counter = Counter()
for target in tweet_data['text']:
    words = wakati(target)
    words = words.split()
    counter.update(words)
        
words = [word for word, cnt in counter.items() if cnt >= threshold]

add_word('<pad>')
add_word('<bos>')
add_word('<eos>')
add_word('<unk>')

for word in words:
    add_word(word)

辞書の作成(但し、出現回数>=2回のものに限る)
* → 文章の末尾に最大長から足りない分だけ付与する * → 文章の開始記号 * → 文章の終了記号 * → 辞書にない言葉(出現回数が少ない)の代わり

データセットの作成

def collate_fn(data):
    data.sort(key=lambda x: x[2], reverse=True)
    images, captions, length = zip(*data)
    images = torch.stack(images, 0)
    targets = torch.stack(captions, 0)
    length = list(length)
    return images, targets, length

def get_dataset(coco, img_feature, bsz):
    data = []
    max_length = 0
    start_time = time.time()
    for i, target in enumerate(coco):
            caption = wakati(target)
            img_feature_caption = (img_feature[i], caption)
            data.append(img_feature_caption)
            if max_length < caption.count(' ') + 1:
                    max_length = caption.count(' ') + 1

    print('data num = ', len(data), 'max len=', max_length)
    data_set = []
    for img, target in data:
        ids = assign_id(target, max_length)
        img_cap_len = (img.data, ids, target.count(' ') + 1)
        data_set.append(img_cap_len)

    return DataLoader(data_set, batch_size=bsz, shuffle=True, drop_last=True, collate_fn=collate_fn)

bsz = 64
tweet_len = len(image_feature)

train_size = int(tweet_len * 0.8)
valid_size = int(tweet_len * 0.9)
train_dataset = get_dataset(tweet_data['text'][:train_size], image_feature[:train_size], bsz)
val_dataset = get_dataset(tweet_data['text'][train_size+1:valid_size], image_feature[train_size+1:valid_size], bsz)
test_dataset = get_dataset(tweet_data['text'][valid_size+1:], image_feature[valid_size+1:], bsz)

今回はvalidとtestは使いませんが、一応作成しています。

文章生成のモデルを作成する

生成モデルの構築

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """Set the hyper-parameters and build the layers."""
        super(DecoderRNN, self).__init__()
        self.encoder = nn.Linear(4096, embed_size)
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seg_length = max_seq_length
        
    def forward(self, features, captions, lengths):
        """Decode image feature vectors and generates captions."""
        features = self.encoder(features)
        embeddings = self.embed(captions)
        embeddings = torch.cat((features, embeddings), 1)
        packed = pack_padded_sequence(embeddings, lengths, batch_first=True) 
        hiddens, _ = self.lstm(packed)
        outputs = self.linear(hiddens[0])
        return outputs
    
    def sample(self, features, states=None):
        """Generate captions for given image features using greedy search."""
        sampled_ids = []
        inputs = features.unsqueeze(1)
        for i in range(self.max_seg_length):
            hiddens, states = self.lstm(inputs, states)
            outputs = self.linear(hiddens.squeeze(1))
            _, predicted = outputs.max(1)
            sampled_ids.append(predicted)
            inputs = self.embed(predicted)
            inputs = inputs.unsqueeze(1)
        sampled_ids = torch.stack(sampled_ids, 1)
        return sampled_ids

各種パラメータの設定

use_gpu = torch.cuda.is_available()
embed_size = 512
hidden_size = 512
num_layers = 1
decoder = DecoderRNN(embed_size, hidden_size, len(word2idx), num_layers)
if use_gpu:
    decoder = decoder.cuda()
learning_rate = 4e-4
criterion = nn.CrossEntropyLoss()
params = list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=learning_rate)

学習

num_epochs = 10000
total_step = len(train_dataset)
log_step = 3000
for epoch in range(num_epochs):
    for i, (features, captions, lengths) in enumerate(train_dataset):
        # Set mini-batch dataset
        if use_gpu:
            features = features.cuda()
            captions = captions.cuda()
        features = Variable(features)
        captions = Variable(captions)
        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

        # Forward, backward and optimize
        outputs = decoder(features, captions, lengths)
        loss = criterion(outputs, targets)
        decoder.zero_grad()
        loss.backward()
        optimizer.step()

        # Print log info
        if i % log_step == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                  .format(epoch, num_epochs, i, total_step, loss.data[0], np.exp(loss.data[0])))

    # Save the model checkpoints
    with open('tw_cap.pt', 'wb') as f:
        torch.save(decoder.state_dict(), f)

epoch数10000とlog_stepは適当です。適せん変更してください。

学習をさせ、テスト結果を確認する

tmp = (適当な画像No.)
feature = image_feature[tmp]
sample_ids = decoder.sample(feature)
sampled_ids = sample_ids[0].cpu().numpy()
# Convert word_ids to words
sampled_caption = []
for word_id in sampled_ids:
    word = idx2word[word_id]
    sampled_caption.append(word)
    if word == '<eos>':
        break
sentence = ' '.join(sampled_caption)
    
print (sentence)

dir = '(画像のあるディレクトリ)'
img = Image.open(dir + tweet_list[tmp][1])
plt.imshow(img)

生成例

入力
f:id:obq777:20180621101006p:plain
出力
< bos >< bos >昼もパスタにしよう< eos >< eos >

入力
f:id:obq777:20180621112051p:plain
出力
< bos > < bos > 【 大喜 利 】 海鮮 丼 を 食べる 際 、 他 の 魚 は 豪快 に

失敗しとるやないか
< bos >などの記号が2回生成されてますね。適せん除去すればいい感じに
< bos >と< eos >に囲まれた文章はなかなかいい感じです。
パスタや海鮮丼など、画像の特徴はとれていると思います。

おしまい

ありがとうございました。