Finding the shortest path through a maze with reinforcement learning (Julia)

Overview

I wrote a program that finds a path through a simple maze using reinforcement learning.
The code and results are included at the end.

Background

I only knew that reinforcement learning existed, so I studied it just enough to understand the basic principles while writing a simple implementation.
In other words, I don't know the finer details.
This is also Julia practice: when I implement things in Python, I inevitably end up copy-pasting code from other sites.

Environment

  • Julia v1.0
  • Jupyter

What it means to solve a maze with reinforcement learning

My rough understanding is that the agent explores the same maze many times (on the order of a thousand runs) and learns which routes are good.
There is an objective (here: "reach the goal" and "don't step on panels that lower the score"), and the agent optimizes its behavior toward that objective.
To do so, it needs an indicator of which direction is best to move in from each position.
That indicator is acquired through learning, and following it through the maze yields the shortest path with the highest score.
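
In the appendix code, this indicator is the dictionary Qvalue, keyed by a (current position, next position) pair, so "how good is stepping from (2, 2) to (2, 3)?" is a single lookup. The snippet below is illustrative only and assumes the trained Qvalue from the appendix is in scope:

# Illustrative lookup against the trained Q-table from the appendix;
# unknown pairs default to 0.0, matching get_Qvalue.
q = get(Qvalue, ((2, 2), (2, 3)), 0.0)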

Reinforcement learning

This time I used Q-learning with ε-greedy action selection.

Q-learning - Wikipedia
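
The heart of Q-learning is the tabular update below; it is exactly what update_Qvalue in the appendix computes, with alpha = 0.2 as the learning rate and gamma = 0.9 as the discount factor:

Q(s, a) \leftarrow Q(s, a) + \alpha \left( r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right)

Here s is the current position, a is the move taken, r is the score of the panel the move lands on, and s' is the position after the move.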

Put simply, the ε-greedy method means occasionally making a random move through the maze instead of the greedy one.
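
As a minimal, self-contained sketch of that idea (the function name and arguments here are purely illustrative; in the appendix, choose_action plays this role with e_greedy_ratio as the exploration probability):

# ε-greedy: with probability `epsilon` pick a random action, otherwise the best-valued one.
function epsilon_greedy(actions, values, epsilon)
    if rand() < epsilon
        return rand(actions)            # explore
    else
        return actions[argmax(values)]  # exploit
    end
end

epsilon_greedy([(2, 3), (3, 2)], [0.5, 1.2], 0.2)  # usually returns (3, 2)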

Aside

About the indicator mentioned above: in implementation terms, it is easiest to think of it as a data space of size (number of maze positions) × (number of directions).
For this small maze that data space poses no problem, but for larger problems it clearly blows up (Mario, for example).
Representing this indicator with a neural network instead is, apparently, what is called a "DQN (Deep Q-Network)".
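
To give a sense of scale for this maze (a back-of-the-envelope count, assuming the 7×7 field's 5×5 interior is walkable and there are at most 4 moves per cell):

# Upper bound on the Q-table size for this maze
n_positions  = 5 * 5  # walkable cells inside the 7x7 field
n_directions = 4      # up, down, left, right
println(n_positions * n_directions)  # => at most 100 (state, action) entries

A hundred entries fits comfortably in a Dict, which is why the table-based approach works here.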

Thoughts

Maze search problems (depth-first and breadth-first search and the like) are well known, and I had tried to implement them a few times before giving up out of laziness, but this approach felt doable.
Going forward, I would like to either implement a Deep Q-Network or build a Julia GUI app so that mazes can be created in the app.

Appendix

# Maze definition: "#" = wall, "S" = start, numbers = panel scores (100 is the goal)
str_field = 
    """
    #,#,#,#,#,#,#
    #,S,0,0,-10,0,#
    #,0,-10,0,0,0,#
    #,0,-10,0,-10,0,#
    #,0,0,0,-10,0,#
    #,0,-10,0,0,100,#
    #,#,#,#,#,#,#
    """
start_point = (2, 2)
goal_point = (6, 6)
field = map(x -> split(x, ","), split(str_field))  # 2D array of cell strings
alpha = 0.2           # learning rate
gamma = 0.9           # discount factor
e_greedy_ratio = 0.2  # probability of taking a random action (exploration)
epochs = 1000         # number of training episodes
Qvalue = Dict()       # Q-table: (state, action) => learned value
function display_field(state, action)
    # Copy the field and mark the current position with "@"
    tmp_field = deepcopy(field)
    tmp_field[state[1]][state[2]] = "@"
    for i in 1:size(tmp_field, 1)
        println(tmp_field[i])
    end
    @show state, action
end
function get_actions(point)
    """ Return the neighboring cells reachable from the current position """
    x, y = point
    around_map = [(x, y-1), (x, y+1), (x-1, y), (x+1, y)]
    # Walls ("#") are blocked; the start ("S") is excluded so the agent never steps back onto it
    return [hoge for hoge in around_map if !(field[hoge[1]][hoge[2]] in ["#", "S"])]
end
function get_Qvalue(state, action)
    # Unseen (state, action) pairs default to 0.0
    return get(Qvalue, (state, action), 0.0)
end
function choose_action_greedy(state)
    """ Choose the action with the highest Q-value (ties broken at random) """
    best_actions = []
    max_q_value = -Inf  # start below any possible Q-value
    for a in get_actions(state)
        q_value = get_Qvalue(state, a)
        if q_value > max_q_value
            best_actions = [a]
            max_q_value = q_value
        elseif q_value == max_q_value
            push!(best_actions, a)
        end
    end
    return rand(best_actions)
end
function choose_action(state)
    """ Choose an action with the ε-greedy method """
    # With probability e_greedy_ratio explore randomly, otherwise act greedily
    if rand() < e_greedy_ratio
        return rand(get_actions(state))
    else
        return choose_action_greedy(state)
    end
end
function set_Qvalue(state, action, q_value)
    hoge = (state, action)
    Qvalue[hoge] = q_value
end
function update_Qvalue(state, action)
    # Q-learning update: Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a)),
    # where the next state s' is the cell the action moves onto
    Q_s_a = get_Qvalue(state, action)
    mQ_s_a = maximum([get_Qvalue(action, n_action) for n_action in get_actions(action)])
    r_s_a = parse(Int64, field[action[1]][action[2]])
    # calculate
    q_value = Q_s_a + alpha * (r_s_a + gamma * mQ_s_a - Q_s_a)
    # update
    set_Qvalue(state, action, q_value)
    # the episode ends as soon as a non-zero reward panel (penalty or goal) is hit
    return r_s_a != 0.0
end
function train()
    """ Run one training episode through the maze """
    state = start_point
    while true
        action = choose_action(state)
        if update_Qvalue(state, action)
            break           # hit a penalty panel or the goal
        else
            state = action  # the chosen cell becomes the next state
        end
    end
end
function test()
    # Follow the learned greedy policy from the start, printing each step
    state = start_point
    count = 0
    reward = 0
    while true
        count += 1
        try
            reward += parse(Int64, field[state[1]][state[2]])
        catch
            reward += 0  # the start cell "S" is not a number
        end
        action = choose_action_greedy(state)
        display_field(state, action)
        if state == goal_point
            @show count, reward
            break
        end
        state = action
    end
end
# Train for `epochs` episodes, then follow the learned greedy policy once
for i in 1:epochs
    train()
end

test()

Results

  • # -> wall
  • @ -> current position
  • number -> panel score (100 is the goal)
  • S -> start point
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "@", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((2, 2), (2, 3))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "@", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((2, 3), (2, 4))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "0", "@", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((2, 4), (3, 4))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "@", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((3, 4), (4, 4))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "@", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((4, 4), (5, 4))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "@", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((5, 4), (6, 4))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "@", "0", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((6, 4), (6, 5))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "@", "100", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((6, 5), (6, 6))
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
SubString{String}["#", "S", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "0", "#"]
SubString{String}["#", "0", "-10", "0", "-10", "0", "#"]
SubString{String}["#", "0", "0", "0", "-10", "0", "#"]
SubString{String}["#", "0", "-10", "0", "0", "@", "#"]
SubString{String}["#", "#", "#", "#", "#", "#", "#"]
(state, action) = ((6, 6), (6, 5))
(count, reward) = (9, 100)