まぃふぇいばりっと

機械学習やってます.Julia大好きです.勉強したことの殴り書きです.

Julia言語 Sinkhorn Knopp アルゴリズムを実装した.

先日の記事のコードをJulia化した.

using LinearAlgebra

function quick_convergence_check(v, v_new, u, u_new, eps)
    if norm( abs.(v - v_new) ) < eps
        if norm( abs.(u - u_new) ) < eps
            return true
        end
    end
    return false
end

function main(A, r, c ; eps = 1.0E-7)
    r = vec(r)
    c = vec(c)
    # vector r and c should be normalized.
    rs = sum(r)
    cs = sum(c)
    r = r / sum(r)
    c = c / sum(c)

    # Initial Step
    v = copy(r)
    u = copy(c)

    cnt = 0
    while true
        # Sinkhorn-knopp Step
        v_new = r ./ ( A' * u )
        u_new = c ./ ( A * v_new )

        if quick_convergence_check(v, v_new, u, u_new, eps)
            break
        end

        # Initial Step
        if cnt > 25000
            println("calculation did not converge")
            exit()
        end

        v = v_new
        u = u_new
        cnt += 1
    end

    P = diagm( vec(u) ) * A * diagm( vec( v ))

    return P
end

A = [1 5 3 6;
     9 5 5 6;
     1 7 10 9;
     2 3 1 4]
r = [7 7 4 2]
c = [1 2 3 4]

P = main(A,r,c)
display(P)
println("\n", sum(P,dims=1) )
println("\n", sum(P,dims=2) )