まぃふぇいばりっと

機械学習やってます.Julia大好きです.勉強したことの殴り書きです.

Julia言語 ボルツマンマシンからのサンプリング

メトロポリス法で,ボルツマンマシンからサンプリングするコードを書きました.

とりあえず,状態xを与えたら,エネルギーを返してくれる関数を用意しておきます.

using Plots
using Random

function energy(x,b,w)
    N = length(b)
    term_1_body = sum( x .* b )
    term_2_body = sum( w[i,j] * x[i] * x[j] for i = 1:N for j = i:N)
    E = term_1_body + term_2_body
    return E
end

function delta_energy(x,y,b,w)
    return energy(y,b,w) - energy(x,b,w)
end

次に標準的なメトロポリス法を実装します.

function sample_BM(b,w;beta=1, max_iter=5000)
    N = length(b) # number of spins
    x = rand([-1,1],N) # initial spins
    
    x_history = []
    for iter = 1:max_iter
        trial_state = copy(x)
        
        # Select one spin
        target_spin_idx = rand(1:N)
        # Flip the spin
        if trial_state[target_spin_idx] == 1 
            #trial_state[target_spin_idx] = 0
            trial_state[target_spin_idx] = -1
        else
            trial_state[target_spin_idx] = 1 
        end
        
        del_ene = delta_energy(trial_state,x,b,w)
        if del_ene < 0
            x = trial_state
        elseif del_ene > 0
            a = rand()
            if a < exp(-beta * del_ene)
                x = trial_state
            end
        end
        if iter % 1000 == 0
            @show iter
            push!(x_history,x)
        end
    end
    
    # you can see the state as 
    # display( reshape(x, (L,M))')
    return x, x_history
end

あとは,相互作用wと磁場bを与えたら動きます. 2D Ising ならこんな感じです.

function define_b_w(M,L,J,H)
    N = M * L
    b = H*ones(N)
    
    w = zeros(N,N)
    for m = 0:M-1
        for l = 1:L
            if l < L
                w[m*L+l, m*L+l+1] = J
            end
            if 1 < l
                w[m*L+l, m*L+l-1] = J
            end
            if 0 < m
                w[m*L+l, m*L+l-L] = J
            end
            if m < M-1
                w[m*L+l, m*L+l+L] = J
            end
        end
    end
    return b,w
end

実行してみます.betaは逆温度で,Jが相互作用の強さ,Hが磁場の強さですね.

M = 40; L = 40; J = 1; H = 0.1; beta=0.1;
b, w = define_b_w(M,L,J,H);

x, x_history = sample_BM(b,w;beta=beta,max_iter=5000)
x = reshape(x, (L,M))';
plt=heatmap(x, aspect_ratio=:equal, axis=false, grid=false, size=(400,400), cbar=false)

ちなみに相互作用を大きくした低温だとこんな感じ.