機械学習で行う画像キャプション (データ収集)

画像キャプションとは

画像に説明文を付与するものです。
それを今回は、機械学習で行いたいと思います。

実行例

f:id:obq777:20180621101006p:plain
昼もパスタにしよう

注意

初心者の記事ですので優しくご覧ください

使用するもの

流れ

  1. 学習用の画像と説明文のペアのデータ収集 ← 今回
  2. 収集したデータを前処理
  3. 機械学習

データ収集

画像キャプションのデータといえば、英語だと MSCOCO、日本語だと STAIR Captionsが有名どころです。
今回は軽くサクッとさせたいため、Twitterからデータを収集します。

機械学習Python + Pytorchで行ないますが、 Rubyでツイートを収集します

require 'twitter'
require 'fileutils'
require 'csv'
require 'open-uri'
require 'uri'

# twitterAPI認証
client = Twitter::REST::Client.new do |config|
    config.consumer_key =
    config.consumer_secret = 
    config.access_token = 
    config.access_token_secret = 
end

# 画像保存
def aImageSave(url)
    # tmp_path = "/Users/takeshi/Documents/twitterImages/#{File.basename(url)}"
    tmp_path = "twitterData/twitterImages/#{File.basename(url)}"
    print "#{url} "
    File.open(tmp_path, 'w+b') do |f|
        begin
            # 404でエラー
            f.write open(url).read
        rescue
            return false
        end
    end
    puts "saved"
    return true
end

# URLの除去
def removeUrl(text)
    urls = []
    URI::extract(text, %w{http https}).uniq.each do |uri|
        urls << uri
    end
    urls.each do |uri|
        text = text.delete(uri)
    end
    return text
end

# ハッシュタグの除去
def removeHashtag(text)
    hashtag = text.scan(/[##][A-Za-zA-Za-z一-鿆0-90-9ぁ-ヶヲ-゚ー]+/).map(&:strip)
    hashtag.each do |tag|
        text = text.delete(tag)
    end
    return text
end

# コンマの除去
def removeComma(text)
    return text.gsub!(/,/, '')
end


# 画像URLからベースネームを取得
def removeDomain(imageUrl)
    splitUrl = imageUrl.chomp.split('/')
    return splitUrl[4]
end

# tweet取得
def getTweet(client)
    # 検索ワード
    words = ["ラーメン", "カツ丼", "天丼", "カレー", "牛丼", "海鮮丼", "親子丼", "定食",
        "オムライス", "パスタ", "ピザ", "サラダ", "焼肉", "蕎麦", ""]
    limit = 80   #180/15m
    words.each do |word|
        max_id = 0
        limit.times do |count|
            tcount = 0
            puts "word = #{word}, count = #{count}"
            client.search(word, lang: "ja", result_type: "recent", exclude: "retweets", max_id: max_id)
                .take(100).each do |tweet|
                # 画像を含んでいる場合
                if !tweet.media.empty?
                    # リプライを含めない
                    if tweet.in_reply_to_user_id?
                        next
                    end
                    # 画像1枚だけのツイートに絞る
                    if tweet.media[0].type == 'photo'
                        if tweet.media[1].nil?
                            saveSuc = aImageSave(tweet.media[0].media_url.to_s)
                            if !saveSuc
                                next
                            end
                            text = removeUrl(tweet.text)
                            text = removeHashtag(text)
                            imageBasename = removeDomain(tweet.media[0].media_url.to_s)
                            CSV.open('twitterData/tweet.csv', 'a', :col_sep => "\t") do |csv|
                                csv << [text, imageBasename, word]
                            end
                            puts "#{text} saved"
                            puts tweet.created_at
                            tcount += 1
                        end
                    end
                end
                max_id = tweet.id
            end
            max_id -= 1
            if tcount == 0
                break
            end
            sleep(10)
        end
    end
end

unless FileTest.exist?("twitterData")
    FileUtils.mkdir_p("twitterData")
end

# CSVファイルが存在しない場合
unless FileTest.exist?('twitterData/tweet.csv')
    CSV.open('twitterData/tweet.csv', 'w', :col_sep => "\t") do |csv|
        csv << ["text", "image", "search word"]
    end
end

# 画像フォルダが存在しない場合
unless FileTest.exist?("twitterData/twitterImages")
    FileUtils.mkdir_p("twitterData/twitterImages")
end

getTweet(client)

解説

今回はTwitterから、食べ物系の画像とツイートの1対1のペアを取得しています。
その際に、ツイートからURLとハッシュタグを除去しています。

API制限ですが、15分間で制限を越えないようにsleepメソッドで休み休み動作します。

次回

次回は上記のプログラムで収集したデータの前処理を行ないます。