先日の記事の Sinkhorn-Knopp アルゴリズムのコードを Julia に移植した.
using LinearAlgebra

"""
    quick_convergence_check(v, v_new, u, u_new, eps)

Return `true` when both scaling vectors changed by less than `eps` (Euclidean
norm) between two consecutive Sinkhorn-Knopp iterations, `false` otherwise.
"""
function quick_convergence_check(v, v_new, u, u_new, eps)
    # `abs.` is unnecessary here: the 2-norm already uses the magnitude of each entry.
    return norm(v - v_new) < eps && norm(u - u_new) < eps
end

"""
    main(A, r, c; eps = 1.0E-7, maxiter = 25_000)

Scale `A` with the Sinkhorn-Knopp iteration and return the matrix
`P[i, j] = u[i] * A[i, j] * v[j]`.

`r` and `c` are flattened with `vec` and normalized to sum to 1 before
iterating. Orientation used by this implementation: `c` is the target for the
row sums of `P` (`sum(P, dims=2)`) and `r` is the target for the column sums
(`sum(P, dims=1)`).

# Arguments
- `A`: matrix to scale, of size `(length(c), length(r))`.
- `r`: target column sums (any array; flattened with `vec`).
- `c`: target row sums (any array; flattened with `vec`).

# Keywords
- `eps`: tolerance on the change of the scaling vectors between iterations.
- `maxiter`: maximum number of iterations before giving up.

# Throws
- `ArgumentError` if `sum(r)` or `sum(c)` is not positive.
- `ErrorException` if the iteration does not converge within `maxiter` steps.
"""
function main(A, r, c; eps = 1.0E-7, maxiter = 25_000)
    rs = sum(r)
    cs = sum(c)
    # A zero (or NaN) total would turn every iterate into NaN and never converge.
    rs > 0 || throw(ArgumentError("sum(r) must be positive, got $rs"))
    cs > 0 || throw(ArgumentError("sum(c) must be positive, got $cs"))

    # The marginals must be normalized so that both carry the same total mass.
    rn = vec(r) ./ rs
    cn = vec(c) ./ cs

    # Initial scaling vectors.
    v = copy(rn)
    u = copy(cn)

    for _ in 1:maxiter
        # Sinkhorn-Knopp step: fit the column sums, then the row sums.
        v_new = rn ./ (A' * u)
        u_new = cn ./ (A * v_new)

        if quick_convergence_check(v, v_new, u, u_new, eps)
            # Use the latest iterates: since `u_new` was computed from `v_new`,
            # the row sums of P equal `cn` up to rounding.
            return u_new .* A .* v_new'
        end

        v = v_new
        u = u_new
    end

    # Throwing (instead of `exit()`) reports the failure to the caller and gives
    # scripts a nonzero exit status without killing an interactive session.
    error("calculation did not converge")
end

A = [1 5 3 6; 9 5 5 6; 1 7 10 9; 2 3 1 4]
r = [7 7 4 2]
c = [1 2 3 4]
P = main(A, r, c)
display(P)
println("\n", sum(P, dims = 1))
println("\n", sum(P, dims = 2))